Skip to content

Commit

Permalink
chore: reformat code, fix typing, sort imports
Browse files Browse the repository at this point in the history
  • Loading branch information
maugde committed Jan 7, 2025
1 parent 6ee154b commit cb0cd7c
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 36 deletions.
17 changes: 5 additions & 12 deletions antarest/login/utils.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,22 @@
import typing as t
from contextvars import ContextVar
from optparse import Option
from typing import Optional

from starlette.requests import Request

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp
from typing_extensions import override

from antarest.core.jwt import JWTUser
from antarest.core.serialization import from_json
from antarest.fastapi_jwt_auth import AuthJWT

_current_user: ContextVar[t.Optional[AuthJWT]] = ContextVar("_current_user", default=None)


class CurrentUserMiddleware(BaseHTTPMiddleware):
def __init__(self, app: t.Optional[ASGIApp]) -> None:
super().__init__(app)

async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
@override
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
global _current_user
auth_jwt = AuthJWT(Request(request.scope))
_current_user.set(auth_jwt)
Expand Down
2 changes: 1 addition & 1 deletion antarest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@
from antarest.core.swagger import customize_openapi
from antarest.core.tasks.model import cancel_orphan_tasks
from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware
from antarest.login.utils import CurrentUserMiddleware
from antarest.core.utils.utils import get_local_path
from antarest.core.utils.web import tags_metadata
from antarest.fastapi_jwt_auth import AuthJWT
from antarest.front import add_front_app
from antarest.login.auth import Auth, JwtSettings
from antarest.login.model import init_admin_user
from antarest.login.utils import CurrentUserMiddleware
from antarest.matrixstore.matrix_garbage_collector import MatrixGarbageCollector
from antarest.service_creator import SESSION_ARGS, Module, create_services, init_db_engine
from antarest.singleton_services import start_all_services
Expand Down
12 changes: 6 additions & 6 deletions antarest/study/business/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@


def execute_or_add_commands(
study: Study,
file_study: FileStudy,
commands: t.Sequence[ICommand],
storage_service: StudyStorageService,
listener: t.Optional[ICommandListener] = None,
jwt_user: JWTUser = None,
study: Study,
file_study: FileStudy,
commands: t.Sequence[ICommand],
storage_service: StudyStorageService,
listener: t.Optional[ICommandListener] = None,
jwt_user: t.Optional[JWTUser] = None,
) -> None:
# get current user if not in session, otherwise get session user
current_user = get_current_user() or jwt_user
Expand Down
11 changes: 2 additions & 9 deletions antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,14 +247,7 @@ def _generate_timeseries(self, notifier: ITaskNotifier) -> None:
command = GenerateThermalClusterTimeSeries(
command_context=command_context, study_version=file_study.config.version
)
execute_or_add_commands(
study,
file_study,
[command],
self.storage_service,
listener,
self.jwt_user
)
execute_or_add_commands(study, file_study, [command], self.storage_service, listener, self.jwt_user)

if isinstance(study, VariantStudy):
# In this case we only added the command to the list.
Expand Down Expand Up @@ -2590,7 +2583,7 @@ def generate_timeseries(self, study: Study, params: RequestParameters) -> str:
repository=self.repository,
storage_service=self.storage_service,
event_bus=self.event_bus,
jwt_user=params.user
jwt_user=params.user, # type: ignore
)

return self.task_service.add_task(
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ def test_ts_generation_task(
repository=study_service.repository,
storage_service=study_service.storage_service,
event_bus=study_service.event_bus,
jwt_user=jwt_user
jwt_user=jwt_user,
)

task_id = study_service.task_service.add_task(
Expand Down
8 changes: 6 additions & 2 deletions tests/storage/business/test_arealink_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def test_area_crud(empty_study: FileStudy, matrix_service: SimpleMatrixService):
with patch("antarest.study.business.utils.get_current_user", return_value=DEFAULT_ADMIN_USER):
area_manager.create_area(study, AreaCreationDTO(name="test", type=AreaType.AREA))
assert len(empty_study.config.areas.keys()) == 1
assert json.loads((empty_study.config.study_path / "patch.json").read_text())["areas"]["test"]["country"] is None
assert (
json.loads((empty_study.config.study_path / "patch.json").read_text())["areas"]["test"]["country"] is None
)

area_manager.update_area_ui(study, "test", UpdateAreaUi(x=100, y=200, color_rgb=(255, 0, 100)))
assert empty_study.tree.get(["input", "areas", "test", "ui", "ui"]) == {
Expand Down Expand Up @@ -174,7 +176,9 @@ def test_area_crud(empty_study: FileStudy, matrix_service: SimpleMatrixService):
RequestParameters(DEFAULT_ADMIN_USER),
)
assert (empty_study.config.study_path / "patch.json").exists()
assert json.loads((empty_study.config.study_path / "patch.json").read_text())["areas"]["test"]["country"] == "FR"
assert (
json.loads((empty_study.config.study_path / "patch.json").read_text())["areas"]["test"]["country"] == "FR"
)

area_manager.update_area_ui(study, "test", UpdateAreaUi(x=100, y=200, color_rgb=(255, 0, 100)))
variant_study_service.append_commands.assert_called_with(
Expand Down
10 changes: 5 additions & 5 deletions tests/study/storage/variantstudy/test_variant_study_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,11 @@ def test_generate_task(

with patch("antarest.study.business.utils.get_current_user", return_value=DEFAULT_ADMIN_USER):
execute_or_add_commands(
variant_study,
file_study,
commands=[create_area_fr, create_st_storage],
storage_service=study_storage_service,
)
variant_study,
file_study,
commands=[create_area_fr, create_st_storage],
storage_service=study_storage_service,
)

## Run the "generate" task
actual_uui = variant_study_service.generate_task(
Expand Down

0 comments on commit cb0cd7c

Please sign in to comment.