diff --git a/antarest/study/storage/patch_service.py b/antarest/study/storage/patch_service.py index e3ece9071e..c52752ae25 100644 --- a/antarest/study/storage/patch_service.py +++ b/antarest/study/storage/patch_service.py @@ -1,27 +1,32 @@ -import logging +import json +import typing as t from pathlib import Path -from typing import Optional, Union from antarest.study.model import Patch, PatchOutputs, RawStudy, StudyAdditionalData from antarest.study.repository import StudyMetadataRepository from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy from antarest.study.storage.variantstudy.model.dbmodel import VariantStudy -logger = logging.getLogger(__name__) +PATCH_JSON = "patch.json" class PatchService: - def __init__(self, repository: Optional[StudyMetadataRepository] = None): + """ + Handle patch file ("patch.json") for a RawStudy or VariantStudy + """ + + def __init__(self, repository: t.Optional[StudyMetadataRepository] = None): self.repository = repository - def get(self, study: Union[RawStudy, VariantStudy], get_from_file: bool = False) -> Patch: - if not get_from_file: + def get(self, study: t.Union[RawStudy, VariantStudy], get_from_file: bool = False) -> Patch: + if not get_from_file and study.additional_data is not None: # the `study.additional_data.patch` field is optional - if patch_data := study.additional_data.patch: - return Patch.parse_raw(patch_data) + if study.additional_data.patch: + patch_obj = json.loads(study.additional_data.patch or "{}") + return Patch.parse_obj(patch_obj) patch = Patch() - patch_path = Path(study.path) / "patch.json" + patch_path = Path(study.path) / PATCH_JSON if patch_path.exists(): patch = Patch.parse_file(patch_path) @@ -29,14 +34,14 @@ def get(self, study: Union[RawStudy, VariantStudy], get_from_file: bool = False) def get_from_filestudy(self, file_study: FileStudy) -> Patch: patch = Patch() - patch_path = (Path(file_study.config.study_path)) / "patch.json" + patch_path = (Path(file_study.config.study_path)) / PATCH_JSON if patch_path.exists(): patch = Patch.parse_file(patch_path) return patch def set_reference_output( self, - study: Union[RawStudy, VariantStudy], + study: t.Union[RawStudy, VariantStudy], output_id: str, status: bool = True, ) -> None: @@ -47,12 +52,12 @@ def set_reference_output( patch.outputs = PatchOutputs(reference=output_id) self.save(study, patch) - def save(self, study: Union[RawStudy, VariantStudy], patch: Patch) -> None: + def save(self, study: t.Union[RawStudy, VariantStudy], patch: Patch) -> None: if self.repository: study.additional_data = study.additional_data or StudyAdditionalData() study.additional_data.patch = patch.json() self.repository.save(study) - patch_path = (Path(study.path)) / "patch.json" + patch_path = (Path(study.path)) / PATCH_JSON patch_path.parent.mkdir(parents=True, exist_ok=True) patch_path.write_text(patch.json()) diff --git a/antarest/study/storage/rawstudy/raw_study_service.py b/antarest/study/storage/rawstudy/raw_study_service.py index f5f98c97e5..4990d70bf6 100644 --- a/antarest/study/storage/rawstudy/raw_study_service.py +++ b/antarest/study/storage/rawstudy/raw_study_service.py @@ -1,10 +1,10 @@ import logging import shutil import time +import typing as t from datetime import datetime from pathlib import Path from threading import Thread -from typing import BinaryIO, List, Optional, Sequence from uuid import uuid4 from zipfile import ZipFile @@ -61,7 +61,7 @@ def __init__( ) self.cleanup_thread.start() - def update_from_raw_meta(self, metadata: RawStudy, fallback_on_default: Optional[bool] = False) -> None: + def update_from_raw_meta(self, metadata: RawStudy, fallback_on_default: t.Optional[bool] = False) -> None: """ Update metadata from study raw metadata Args: @@ -90,7 +90,7 @@ def update_from_raw_meta(self, metadata: RawStudy, fallback_on_default: Optional metadata.version = metadata.version or 0 metadata.created_at = metadata.created_at or datetime.utcnow() metadata.updated_at = metadata.updated_at or datetime.utcnow() - if not metadata.additional_data: + if metadata.additional_data is None: metadata.additional_data = StudyAdditionalData() metadata.additional_data.patch = metadata.additional_data.patch or Patch().json() metadata.additional_data.author = metadata.additional_data.author or "Unknown" @@ -148,7 +148,7 @@ def get_raw( self, metadata: RawStudy, use_cache: bool = True, - output_dir: Optional[Path] = None, + output_dir: t.Optional[Path] = None, ) -> FileStudy: """ Fetch a study object and its config @@ -163,7 +163,7 @@ def get_raw( study_path = self.get_study_path(metadata) return self.study_factory.create_from_fs(study_path, metadata.id, output_dir, use_cache=use_cache) - def get_synthesis(self, metadata: RawStudy, params: Optional[RequestParameters] = None) -> FileStudyTreeConfigDTO: + def get_synthesis(self, metadata: RawStudy, params: t.Optional[RequestParameters] = None) -> FileStudyTreeConfigDTO: self._check_study_exists(metadata) study_path = self.get_study_path(metadata) study = self.study_factory.create_from_fs(study_path, metadata.id) @@ -206,7 +206,7 @@ def copy( self, src_meta: RawStudy, dest_name: str, - groups: Sequence[str], + groups: t.Sequence[str], with_outputs: bool = False, ) -> RawStudy: """ @@ -223,7 +223,7 @@ def copy( """ self._check_study_exists(src_meta) - if not src_meta.additional_data: + if src_meta.additional_data is None: additional_data = StudyAdditionalData() else: additional_data = StudyAdditionalData( @@ -295,7 +295,7 @@ def delete_output(self, metadata: RawStudy, output_name: str) -> None: output_path.unlink(missing_ok=True) remove_from_cache(self.cache, metadata.id) - def import_study(self, metadata: RawStudy, stream: BinaryIO) -> Study: + def import_study(self, metadata: RawStudy, stream: t.BinaryIO) -> Study: """ Import study in the directory of the study. @@ -329,7 +329,7 @@ def export_study_flat( metadata: RawStudy, dst_path: Path, outputs: bool = True, - output_list_filter: Optional[List[str]] = None, + output_list_filter: t.Optional[t.List[str]] = None, denormalize: bool = True, ) -> None: try: @@ -352,7 +352,7 @@ def export_study_flat( def check_errors( self, metadata: RawStudy, - ) -> List[str]: + ) -> t.List[str]: """ Check study antares data integrity Args: diff --git a/antarest/study/storage/variantstudy/variant_study_service.py b/antarest/study/storage/variantstudy/variant_study_service.py index e59ef3fa94..582e251013 100644 --- a/antarest/study/storage/variantstudy/variant_study_service.py +++ b/antarest/study/storage/variantstudy/variant_study_service.py @@ -3,10 +3,10 @@ import logging import re import shutil +import typing as t from datetime import datetime from functools import reduce from pathlib import Path -from typing import Callable, List, Optional, Sequence, Tuple, cast from uuid import uuid4 from fastapi import HTTPException @@ -101,11 +101,11 @@ def get_command(self, study_id: str, command_id: str, params: RequestParameters) try: index = [command.id for command in study.commands].index(command_id) # Maybe add Try catch for this - return cast(CommandDTO, study.commands[index].to_dto()) + return t.cast(CommandDTO, study.commands[index].to_dto()) except ValueError: raise CommandNotFoundError(f"Command with id {command_id} not found") from None - def get_commands(self, study_id: str, params: RequestParameters) -> List[CommandDTO]: + def get_commands(self, study_id: str, params: RequestParameters) -> t.List[CommandDTO]: """ Get command lists Args: @@ -116,8 +116,8 @@ def get_commands(self, study_id: str, params: RequestParameters) -> List[Command study = self._get_variant_study(study_id, params) return [command.to_dto() for command in study.commands] - def _check_commands_validity(self, study_id: str, commands: List[CommandDTO]) -> List[ICommand]: - command_objects: List[ICommand] = [] + def _check_commands_validity(self, study_id: str, commands: t.List[CommandDTO]) -> t.List[ICommand]: + command_objects: t.List[ICommand] = [] for i, command in enumerate(commands): try: command_objects.extend(self.command_factory.to_command(command)) @@ -157,9 +157,9 @@ def append_command(self, study_id: str, command: CommandDTO, params: RequestPara def append_commands( self, study_id: str, - commands: List[CommandDTO], + commands: t.List[CommandDTO], params: RequestParameters, - ) -> List[str]: + ) -> t.List[str]: """ Add command to list of commands (at the end) Args: @@ -196,7 +196,7 @@ def append_commands( def replace_commands( self, study_id: str, - commands: List[CommandDTO], + commands: t.List[CommandDTO], params: RequestParameters, ) -> str: """ @@ -320,13 +320,13 @@ def export_commands_matrices(self, study_id: str, params: RequestParameters) -> lambda: reduce( lambda m, c: m + c.get_inner_matrices(), self.command_factory.to_command(command.to_dto()), - cast(List[str], []), + t.cast(t.List[str], []), ), lambda e: logger.warning(f"Failed to parse command {command}", exc_info=e), ) or [] } - return cast(MatrixService, self.command_factory.command_context.matrix_service).download_matrix_list( + return t.cast(MatrixService, self.command_factory.command_context.matrix_service).download_matrix_list( list(matrices), f"{study.name}_{study.id}_matrices", params ) @@ -410,7 +410,7 @@ def get_all_variants_children( def walk_children( self, parent_id: str, - fun: Callable[[VariantStudy], None], + fun: t.Callable[[VariantStudy], None], bottom_first: bool, ) -> None: study = self._get_variant_study( @@ -426,13 +426,13 @@ def walk_children( if bottom_first: fun(study) - def get_variants_parents(self, id: str, params: RequestParameters) -> List[StudyMetadataDTO]: - output_list: List[StudyMetadataDTO] = self._get_variants_parents(id, params) + def get_variants_parents(self, id: str, params: RequestParameters) -> t.List[StudyMetadataDTO]: + output_list: t.List[StudyMetadataDTO] = self._get_variants_parents(id, params) if output_list: output_list = output_list[1:] return output_list - def get_direct_parent(self, id: str, params: RequestParameters) -> Optional[StudyMetadataDTO]: + def get_direct_parent(self, id: str, params: RequestParameters) -> t.Optional[StudyMetadataDTO]: study = self._get_variant_study(id, params, raw_study_accepted=True) if study.parent_id is not None: parent = self._get_variant_study(study.parent_id, params, raw_study_accepted=True) @@ -447,7 +447,7 @@ def get_direct_parent(self, id: str, params: RequestParameters) -> Optional[Stud ) return None - def _get_variants_parents(self, id: str, params: RequestParameters) -> List[StudyMetadataDTO]: + def _get_variants_parents(self, id: str, params: RequestParameters) -> t.List[StudyMetadataDTO]: study = self._get_variant_study(id, params, raw_study_accepted=True) metadata = ( self.get_study_information( @@ -458,7 +458,7 @@ def _get_variants_parents(self, id: str, params: RequestParameters) -> List[Stud study, ) ) - output_list: List[StudyMetadataDTO] = [metadata] + output_list: t.List[StudyMetadataDTO] = [metadata] if study.parent_id is not None: output_list.extend( self._get_variants_parents( @@ -530,16 +530,15 @@ def create_variant_study(self, uuid: str, name: str, params: RequestParameters) assert_permission(params.user, study, StudyPermissionType.READ) new_id = str(uuid4()) study_path = str(self.config.get_workspace_path() / new_id) - if study.additional_data: - # noinspection PyArgumentList + if study.additional_data is None: + additional_data = StudyAdditionalData() + else: additional_data = StudyAdditionalData( horizon=study.additional_data.horizon, author=study.additional_data.author, patch=study.additional_data.patch, ) - else: - additional_data = StudyAdditionalData() - # noinspection PyArgumentList + variant_study = VariantStudy( id=new_id, name=name, @@ -653,7 +652,7 @@ def generate_study_config( self, variant_study_id: str, params: RequestParameters, - ) -> Tuple[GenerationResultInfoDTO, FileStudyTreeConfig]: + ) -> t.Tuple[GenerationResultInfoDTO, FileStudyTreeConfig]: # Get variant study variant_study = self._get_variant_study(variant_study_id, params) @@ -667,8 +666,8 @@ def _generate_study_config( self, original_study: VariantStudy, metadata: VariantStudy, - config: Optional[FileStudyTreeConfig], - ) -> Tuple[GenerationResultInfoDTO, FileStudyTreeConfig]: + config: t.Optional[FileStudyTreeConfig], + ) -> t.Tuple[GenerationResultInfoDTO, FileStudyTreeConfig]: parent_study = self.repository.get(metadata.parent_id) if parent_study is None: raise StudyNotFoundError(metadata.parent_id) @@ -698,9 +697,9 @@ def _get_commands_and_notifier( variant_study: VariantStudy, notifier: TaskUpdateNotifier, from_index: int = 0, - ) -> Tuple[List[List[ICommand]], Callable[[int, bool, str], None]]: + ) -> t.Tuple[t.List[t.List[ICommand]], t.Callable[[int, bool, str], None]]: # Generate - commands: List[List[ICommand]] = self._to_commands(variant_study, from_index) + commands: t.List[t.List[ICommand]] = self._to_commands(variant_study, from_index) def notify(command_index: int, command_result: bool, command_message: str) -> None: try: @@ -727,8 +726,8 @@ def notify(command_index: int, command_result: bool, command_message: str) -> No return commands, notify - def _to_commands(self, metadata: VariantStudy, from_index: int = 0) -> List[List[ICommand]]: - commands: List[List[ICommand]] = [ + def _to_commands(self, metadata: VariantStudy, from_index: int = 0) -> t.List[t.List[ICommand]]: + commands: t.List[t.List[ICommand]] = [ self.command_factory.to_command(command_block.to_dto()) for index, command_block in enumerate(metadata.commands) if from_index <= index @@ -740,7 +739,7 @@ def _generate_config( variant_study: VariantStudy, config: FileStudyTreeConfig, notifier: TaskUpdateNotifier = noop_notifier, - ) -> Tuple[GenerationResultInfoDTO, FileStudyTreeConfig]: + ) -> t.Tuple[GenerationResultInfoDTO, FileStudyTreeConfig]: commands, notify = self._get_commands_and_notifier(variant_study=variant_study, notifier=notifier) return self.generator.generate_config(commands, config, variant_study, notifier=notify) @@ -809,7 +808,7 @@ def copy( self, src_meta: VariantStudy, dest_name: str, - groups: Sequence[str], + groups: t.Sequence[str], with_outputs: bool = False, ) -> VariantStudy: """ @@ -826,16 +825,14 @@ def copy( """ new_id = str(uuid4()) study_path = str(self.config.get_workspace_path() / new_id) - if src_meta.additional_data: - # noinspection PyArgumentList + if src_meta.additional_data is None: + additional_data = StudyAdditionalData() + else: additional_data = StudyAdditionalData( horizon=src_meta.additional_data.horizon, author=src_meta.additional_data.author, patch=src_meta.additional_data.patch, ) - else: - additional_data = StudyAdditionalData() - # noinspection PyArgumentList dst_meta = VariantStudy( id=new_id, name=dest_name, @@ -893,7 +890,7 @@ def _safe_generation(self, metadata: VariantStudy, timeout: int = DEFAULT_AWAIT_ @staticmethod def _get_snapshot_last_executed_command_index( study: VariantStudy, - ) -> Optional[int]: + ) -> t.Optional[int]: if study.snapshot and study.snapshot.last_executed_command: last_executed_command_index = [command.id for command in study.commands].index( study.snapshot.last_executed_command @@ -905,7 +902,7 @@ def get_raw( self, metadata: VariantStudy, use_cache: bool = True, - output_dir: Optional[Path] = None, + output_dir: t.Optional[Path] = None, ) -> FileStudy: """ Fetch a study raw tree object and its config @@ -925,7 +922,7 @@ def get_raw( use_cache=use_cache, ) - def get_study_sim_result(self, study: VariantStudy) -> List[StudySimResultDTO]: + def get_study_sim_result(self, study: VariantStudy) -> t.List[StudySimResultDTO]: """ Get global result information Args: @@ -988,7 +985,7 @@ def export_study_flat( metadata: VariantStudy, dst_path: Path, outputs: bool = True, - output_list_filter: Optional[List[str]] = None, + output_list_filter: t.Optional[t.List[str]] = None, denormalize: bool = True, ) -> None: self._safe_generation(metadata) @@ -1009,7 +1006,7 @@ def export_study_flat( def get_synthesis( self, metadata: VariantStudy, - params: Optional[RequestParameters] = None, + params: t.Optional[RequestParameters] = None, ) -> FileStudyTreeConfigDTO: """ Return study synthesis diff --git a/tests/storage/business/test_patch_service.py b/tests/storage/business/test_patch_service.py index 7f792a057b..ed7dd6c444 100644 --- a/tests/storage/business/test_patch_service.py +++ b/tests/storage/business/test_patch_service.py @@ -51,7 +51,7 @@ class TestPatchService: @with_db_context @pytest.mark.parametrize("get_from_file", [True, False]) @pytest.mark.parametrize("file_data", ["", PATCH_CONTENT]) - @pytest.mark.parametrize("patch_data", ["", PATCH_CONTENT]) + @pytest.mark.parametrize("patch_data", [None, "", PATCH_CONTENT]) def test_get( self, tmp_path: Path, @@ -67,7 +67,15 @@ def test_get( patch_json.write_text(file_data, encoding="utf-8") # Prepare a RAW study - # noinspection PyArgumentList + additional_data = ( + None + if patch_data is None + else StudyAdditionalData( + author="john.doe", + horizon="foo-horizon", + patch=patch_data, + ) + ) raw_study = RawStudy( id=study_id, name="my_study", @@ -76,11 +84,7 @@ def test_get( created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), version="840", - additional_data=StudyAdditionalData( - author="john.doe", - horizon="foo-horizon", - patch=patch_data, - ), + additional_data=additional_data, archived=False, owner=None, groups=[],