From fa8bf8f27c6384bf9d659c1be37ad98ddf48252b Mon Sep 17 00:00:00 2001 From: larsevj Date: Tue, 10 Dec 2024 14:09:53 +0100 Subject: [PATCH] Add pyupgrade UP rule to ruff --- docs/ert/conf.py | 1 - docs/everest/conf.py | 1 - pyproject.toml | 4 + src/_ert/async_utils.py | 11 +- src/_ert/forward_model_runner/cli.py | 8 +- src/_ert/forward_model_runner/client.py | 33 ++--- .../forward_model_step.py | 20 ++- .../forward_model_runner/reporting/event.py | 4 +- .../forward_model_runner/reporting/file.py | 4 +- .../reporting/interactive.py | 12 +- .../forward_model_runner/reporting/message.py | 20 +-- .../reporting/statemachine.py | 6 +- src/_ert/forward_model_runner/runner.py | 10 +- src/_ert/threading.py | 5 +- src/ert/__main__.py | 23 ++-- src/ert/analysis/_es_update.py | 35 +++-- src/ert/analysis/misfit_preprocessor.py | 17 +-- src/ert/analysis/snapshots.py | 10 +- src/ert/cli/main.py | 6 +- src/ert/cli/monitor.py | 19 ++- src/ert/config/_get_num_cpu.py | 11 +- src/ert/config/_option_dict.py | 4 +- src/ert/config/_read_summary.py | 58 ++++----- src/ert/config/analysis_config.py | 18 +-- src/ert/config/capture_validation.py | 3 +- src/ert/config/design_matrix.py | 41 +++--- src/ert/config/ert_config.py | 120 ++++++++---------- src/ert/config/ert_plugin.py | 4 +- src/ert/config/ert_script.py | 21 +-- src/ert/config/ext_param_config.py | 5 +- src/ert/config/external_ert_script.py | 4 +- src/ert/config/field.py | 26 ++-- src/ert/config/forward_model_step.py | 8 +- src/ert/config/gen_data_config.py | 18 +-- src/ert/config/gen_kw_config.py | 84 ++++++------ src/ert/config/general_observation.py | 9 +- src/ert/config/model_config.py | 6 +- src/ert/config/observation_vector.py | 9 +- src/ert/config/observations.py | 3 +- src/ert/config/parsing/config_dict.py | 4 +- src/ert/config/parsing/config_errors.py | 5 +- .../parsing/config_schema_deprecations.py | 4 +- src/ert/config/parsing/config_schema_item.py | 33 +++-- src/ert/config/parsing/context_values.py | 10 +- src/ert/config/parsing/deprecation_info.py | 8 +- src/ert/config/parsing/error_info.py | 3 +- src/ert/config/parsing/file_context_token.py | 6 +- .../config/parsing/forward_model_schema.py | 6 +- src/ert/config/parsing/observations_parser.py | 2 +- src/ert/config/parsing/schema_dict.py | 8 +- src/ert/config/parsing/types.py | 12 +- src/ert/config/queue_config.py | 79 ++++++------ src/ert/config/refcase.py | 3 +- src/ert/config/response_config.py | 4 +- src/ert/config/responses_index.py | 12 +- src/ert/config/summary_config.py | 14 +- src/ert/config/workflow.py | 21 +-- src/ert/config/workflow_job.py | 8 +- src/ert/dark_storage/client/_session.py | 5 +- src/ert/dark_storage/client/async_client.py | 4 +- src/ert/dark_storage/client/client.py | 4 +- src/ert/dark_storage/common.py | 3 +- src/ert/dark_storage/compute/misfits.py | 2 +- .../dark_storage/endpoints/compute/misfits.py | 6 +- src/ert/dark_storage/endpoints/experiments.py | 9 +- .../dark_storage/endpoints/observations.py | 5 +- src/ert/dark_storage/endpoints/records.py | 5 +- src/ert/dark_storage/enkf.py | 3 +- src/ert/dark_storage/json_schema/ensemble.py | 3 +- .../dark_storage/json_schema/experiment.py | 3 +- .../dark_storage/json_schema/observation.py | 3 +- src/ert/dark_storage/json_schema/prior.py | 28 ++-- src/ert/dark_storage/json_schema/record.py | 5 +- src/ert/dark_storage/json_schema/update.py | 4 +- src/ert/dark_storage/security.py | 3 +- src/ert/data/_measured_data.py | 6 +- src/ert/enkf_main.py | 17 +-- src/ert/ensemble_evaluator/_ensemble.py | 10 +- .../ensemble_evaluator/_wait_for_evaluator.py | 13 +- src/ert/ensemble_evaluator/config.py | 18 ++- src/ert/ensemble_evaluator/evaluator.py | 25 ++-- .../evaluator_connection_info.py | 5 +- src/ert/ensemble_evaluator/event.py | 9 +- src/ert/ensemble_evaluator/monitor.py | 25 ++-- src/ert/ensemble_evaluator/snapshot.py | 63 +++++---- src/ert/field_utils/grdecl_io.py | 5 +- src/ert/field_utils/roff_io.py | 10 +- src/ert/gui/about_dialog.py | 4 +- src/ert/gui/ertnotifier.py | 10 +- src/ert/gui/ertwidgets/__init__.py | 3 +- .../analysismodulevariablespanel.py | 4 +- src/ert/gui/ertwidgets/checklist.py | 10 +- src/ert/gui/ertwidgets/closabledialog.py | 6 +- src/ert/gui/ertwidgets/copy_button.py | 7 +- src/ert/gui/ertwidgets/copyablelabel.py | 5 +- .../ertwidgets/create_experiment_dialog.py | 10 +- src/ert/gui/ertwidgets/ensembleselector.py | 5 +- src/ert/gui/ertwidgets/listeditbox.py | 14 +- src/ert/gui/ertwidgets/message_box.py | 8 +- .../models/activerealizationsmodel.py | 2 +- src/ert/gui/ertwidgets/models/path_model.py | 4 +- .../models/selectable_list_model.py | 10 +- .../ertwidgets/models/targetensemblemodel.py | 6 +- src/ert/gui/ertwidgets/models/text_model.py | 4 +- src/ert/gui/ertwidgets/models/valuemodel.py | 8 +- src/ert/gui/ertwidgets/pathchooser.py | 4 +- src/ert/gui/ertwidgets/searchbox.py | 8 +- src/ert/gui/ertwidgets/stringbox.py | 6 +- src/ert/gui/ertwidgets/textbox.py | 6 +- src/ert/gui/ertwidgets/validationsupport.py | 20 +-- src/ert/gui/main.py | 23 ++-- src/ert/gui/main_window.py | 4 +- src/ert/gui/model/node.py | 34 ++--- src/ert/gui/model/snapshot.py | 9 +- .../simulation/combobox_with_description.py | 10 +- .../gui/simulation/experiment_config_panel.py | 6 +- src/ert/gui/simulation/experiment_panel.py | 26 ++-- .../multiple_data_assimilation_panel.py | 4 +- src/ert/gui/simulation/queue_emitter.py | 11 +- src/ert/gui/simulation/run_dialog.py | 8 +- src/ert/gui/simulation/view/realization.py | 10 +- src/ert/gui/simulation/view/update.py | 9 +- src/ert/gui/suggestor/suggestor.py | 3 +- .../design_matrix/design_matrix_panel.py | 4 +- src/ert/gui/tools/event_viewer/panel.py | 2 +- src/ert/gui/tools/event_viewer/tool.py | 6 +- src/ert/gui/tools/export/export_panel.py | 15 +-- src/ert/gui/tools/file/file_dialog.py | 5 +- src/ert/gui/tools/file/file_update_worker.py | 5 +- .../tools/load_results/load_results_tool.py | 6 +- .../manage_experiments/storage_info_widget.py | 7 +- .../tools/manage_experiments/storage_model.py | 8 +- .../manage_experiments/storage_widget.py | 2 +- .../gui/tools/plot/customize/color_chooser.py | 8 +- .../plot/customize/customization_view.py | 19 +-- .../plot/customize/customize_plot_dialog.py | 5 +- .../customize/limits_customization_view.py | 56 ++++---- .../statistics_customization_view.py | 6 +- .../gui/tools/plot/customize/style_chooser.py | 2 +- .../gui/tools/plot/data_type_keys_widget.py | 4 +- .../gui/tools/plot/data_type_proxy_model.py | 4 +- src/ert/gui/tools/plot/plot_api.py | 34 +++-- .../plot/plot_ensemble_selection_widget.py | 17 +-- src/ert/gui/tools/plot/plot_widget.py | 18 +-- src/ert/gui/tools/plot/plot_window.py | 34 +++-- .../gui/tools/plot/plottery/plot_config.py | 36 +++--- .../plot/plottery/plot_config_history.py | 6 +- .../gui/tools/plot/plottery/plot_context.py | 24 ++-- .../gui/tools/plot/plottery/plot_limits.py | 54 ++++---- src/ert/gui/tools/plot/plottery/plot_style.py | 5 +- src/ert/gui/tools/plot/plottery/plots/cesp.py | 46 +++---- .../tools/plot/plottery/plots/distribution.py | 16 +-- .../gui/tools/plot/plottery/plots/ensemble.py | 8 +- .../tools/plot/plottery/plots/gaussian_kde.py | 8 +- .../tools/plot/plottery/plots/histogram.py | 27 ++-- .../tools/plot/plottery/plots/plot_tools.py | 24 ++-- .../tools/plot/plottery/plots/statistics.py | 12 +- .../gui/tools/plot/plottery/plots/std_dev.py | 6 +- .../tools/plot/widgets/clearable_line_edit.py | 12 +- .../plot/widgets/copy_style_to_dialog.py | 10 +- .../tools/plot/widgets/custom_date_edit.py | 5 +- .../gui/tools/plot/widgets/filter_popup.py | 12 +- .../plot/widgets/filterable_kw_list_model.py | 8 +- src/ert/gui/tools/plugins/plugin_handler.py | 3 +- src/ert/gui/tools/plugins/plugin_runner.py | 9 +- src/ert/gui/tools/plugins/plugins_tool.py | 4 +- .../gui/tools/plugins/process_job_dialog.py | 18 +-- src/ert/gui/tools/search_bar/search_bar.py | 16 +-- src/ert/gui/tools/tool.py | 8 +- .../tools/workflows/run_workflow_widget.py | 7 +- .../gui/tools/workflows/workflow_dialog.py | 6 +- src/ert/libres_facade.py | 7 +- src/ert/load_status.py | 4 +- src/ert/logging/__init__.py | 8 +- src/ert/namespace.py | 6 +- src/ert/plugins/__init__.py | 3 +- .../forward_model_steps.py | 10 +- src/ert/plugins/hook_implementations/jobs.py | 5 +- .../workflows/disable_parameters.py | 4 +- .../workflows/export_misfit_data.py | 4 +- .../workflows/export_runpath.py | 4 +- .../workflows/misfit_preprocessor.py | 4 +- .../plugins/hook_specifications/__init__.py | 5 +- .../forward_model_steps.py | 6 +- .../hook_specifications/help_resources.py | 4 +- src/ert/plugins/hook_specifications/jobs.py | 8 +- .../hook_specifications/site_config.py | 4 +- src/ert/plugins/plugin_manager.py | 17 +-- src/ert/plugins/workflow_config.py | 11 +- .../forward_models/res/script/ecl_config.py | 14 +- .../forward_models/res/script/ecl_run.py | 8 +- .../forward_models/template_render.py | 2 +- .../shell_scripts/careful_copy_file.py | 4 +- .../resources/shell_scripts/copy_directory.py | 4 +- src/ert/resources/shell_scripts/copy_file.py | 4 +- .../shell_scripts/delete_directory.py | 4 +- .../resources/shell_scripts/delete_file.py | 4 +- .../resources/shell_scripts/make_directory.py | 2 +- .../resources/shell_scripts/move_directory.py | 4 +- src/ert/resources/shell_scripts/move_file.py | 4 +- src/ert/resources/shell_scripts/symlink.py | 4 +- .../scripts/gen_data_rft_export.py | 2 +- src/ert/run_arg.py | 6 +- src/ert/run_models/base_run_model.py | 33 ++--- src/ert/run_models/event.py | 9 +- src/ert/run_models/everest_run_model.py | 28 ++-- .../run_models/iterated_ensemble_smoother.py | 10 +- src/ert/run_models/model_factory.py | 4 +- src/ert/run_models/single_test_run.py | 6 +- src/ert/runpaths.py | 2 +- src/ert/scheduler/driver.py | 14 +- src/ert/scheduler/event.py | 3 +- src/ert/scheduler/job.py | 4 +- src/ert/scheduler/local_driver.py | 12 +- src/ert/scheduler/lsf_driver.py | 8 +- src/ert/scheduler/openpbs_driver.py | 5 +- src/ert/scheduler/scheduler.py | 54 ++++---- src/ert/scheduler/slurm_driver.py | 52 ++++---- src/ert/services/_base_service.py | 36 ++---- src/ert/services/storage_service.py | 15 ++- src/ert/shared/_doc_utils/everest_jobs.py | 8 +- .../_doc_utils/forward_model_documentation.py | 9 +- src/ert/shared/net_utils.py | 7 +- src/ert/shared/plugins/plugin_response.py | 9 +- src/ert/shared/storage/connection.py | 6 +- src/ert/shared/storage/extraction.py | 6 +- src/ert/storage/__init__.py | 3 +- src/ert/storage/local_ensemble.py | 33 ++--- src/ert/storage/local_experiment.py | 16 +-- src/ert/storage/local_storage.py | 11 +- src/ert/storage/migration/to5.py | 2 +- src/ert/storage/migration/to7.py | 2 +- src/ert/storage/migration/to8.py | 6 +- src/ert/storage/migration/to9.py | 4 +- src/ert/storage/mode.py | 7 +- src/ert/substitutions.py | 5 +- .../ensemble_realizations_argument.py | 4 +- src/ert/validation/integer_argument.py | 5 +- src/ert/validation/range_string_argument.py | 4 +- src/ert/validation/rangestring.py | 2 +- src/ert/validation/validation_status.py | 7 +- src/ert/workflow_runner.py | 28 ++-- src/everest/bin/config_branch_script.py | 4 +- src/everest/bin/main.py | 4 +- src/everest/bin/utils.py | 10 +- src/everest/config/control_config.py | 37 +++--- src/everest/config/control_variable_config.py | 35 +++-- src/everest/config/cvar_config.py | 6 +- src/everest/config/environment_config.py | 16 +-- src/everest/config/everest_config.py | 59 +++++---- src/everest/config/export_config.py | 14 +- src/everest/config/has_ert_queue_options.py | 6 +- src/everest/config/input_constraint_config.py | 12 +- src/everest/config/install_data_config.py | 3 +- src/everest/config/install_template_config.py | 4 +- src/everest/config/model_config.py | 8 +- .../config/objective_function_config.py | 12 +- src/everest/config/optimization_config.py | 32 ++--- .../config/output_constraint_config.py | 12 +- src/everest/config/sampler_config.py | 6 +- src/everest/config/server_config.py | 16 +-- src/everest/config/simulator_config.py | 64 +++++----- src/everest/config/validation_utils.py | 37 +++--- src/everest/config/well_config.py | 5 +- src/everest/config/workflow_config.py | 6 +- src/everest/config_file_loader.py | 16 +-- src/everest/detached/__init__.py | 25 ++-- src/everest/detached/jobs/everserver.py | 12 +- .../docs/generate_docs_from_config_spec.py | 9 +- src/everest/export.py | 30 ++--- src/everest/jobs/__init__.py | 10 +- src/everest/jobs/io/__init__.py | 2 +- src/everest/jobs/templating/render.py | 4 +- src/everest/optimizer/everest2ropt.py | 93 +++++++------- src/everest/plugins/everest_plugin_manager.py | 6 +- src/everest/plugins/hook_specs.py | 7 +- src/everest/queue_driver/queue_driver.py | 6 +- src/everest/simulator/everest_to_ert.py | 17 +-- src/everest/simulator/simulator_cache.py | 13 +- src/everest/util/__init__.py | 4 +- src/everest/util/async_run.py | 2 +- src/everest/util/forward_models.py | 8 +- .../workflows/jobs/realization_number.py | 2 +- test-data/ert/heat_equation/generate_files.py | 8 +- test-data/ert/heat_equation/heat_equation.py | 21 ++- .../forward_models/snake_oil_simulator.py | 4 +- .../forward_models/snake_oil_simulator.py | 6 +- .../everest/math_func/jobs/adv_distance3.py | 2 +- .../math_func/jobs/adv_dump_controls.py | 2 +- test-data/everest/math_func/jobs/discrete.py | 2 +- test-data/everest/math_func/jobs/distance3.py | 2 +- .../everest/math_func/jobs/dump_controls.py | 2 +- tests/ert/__init__.py | 31 ++--- tests/ert/conftest.py | 8 +- tests/ert/performance_tests/test_analysis.py | 4 +- .../test_dark_storage_performance.py | 3 +- .../test_obs_and_responses_performance.py | 3 +- .../performance_tests/test_read_summary.py | 4 +- .../ui_tests/cli/analysis/test_es_update.py | 4 +- tests/ert/ui_tests/cli/test_cli.py | 10 +- .../ui_tests/cli/test_parameter_passing.py | 12 +- tests/ert/ui_tests/gui/conftest.py | 17 ++- .../gui/test_full_manual_update_workflow.py | 7 +- .../gui/test_load_results_manually.py | 5 +- tests/ert/ui_tests/gui/test_main_window.py | 10 +- .../ert/ui_tests/gui/test_missing_runpath.py | 4 +- .../gui/test_restart_ensemble_experiment.py | 3 +- ...est_restart_no_responses_and_parameters.py | 4 +- .../ert/ui_tests/gui/test_single_test_run.py | 5 +- tests/ert/ui_tests/gui/test_workflow_tool.py | 4 +- .../config/config_dict_generator.py | 40 +++--- .../ert/unit_tests/config/egrid_generator.py | 16 +-- .../config/observations_generator.py | 27 ++-- .../config/parsing/test_lark_parser.py | 6 +- .../ert/unit_tests/config/test_ert_config.py | 2 +- .../config/test_forward_model_data_to_json.py | 9 +- .../unit_tests/config/test_gen_data_config.py | 3 +- .../unit_tests/config/test_gen_kw_config.py | 2 +- .../unit_tests/config/test_observations.py | 4 +- .../config/test_parser_error_collection.py | 28 ++-- .../unit_tests/config/test_read_summary.py | 4 +- .../config/test_transfer_functions.py | 46 +++---- tests/ert/unit_tests/conftest.py | 5 +- .../ensemble_evaluator_utils.py | 4 +- .../ensemble_evaluator/test_scheduler.py | 6 +- .../ensemble_evaluator/test_snapshot.py | 5 +- .../test_file_reporter.py | 28 ++-- .../test_forward_model_runner.py | 6 +- .../test_forward_model_step.py | 2 +- .../gui/simulation/view/test_legend.py | 6 +- .../gui/tools/plot/test_plot_window.py | 5 +- .../unit_tests/plugins/test_plugin_manager.py | 4 +- tests/ert/unit_tests/resources/test_shell.py | 10 +- .../unit_tests/resources/test_templating.py | 6 +- tests/ert/unit_tests/scheduler/bin/bhist.py | 9 +- tests/ert/unit_tests/scheduler/bin/bjobs.py | 8 +- tests/ert/unit_tests/scheduler/bin/qstat.py | 6 +- tests/ert/unit_tests/scheduler/bin/sacct.py | 4 +- .../ert/unit_tests/scheduler/bin/scontrol.py | 4 +- tests/ert/unit_tests/scheduler/bin/squeue.py | 4 +- tests/ert/unit_tests/scheduler/conftest.py | 3 +- tests/ert/unit_tests/scheduler/test_job.py | 3 +- .../unit_tests/scheduler/test_lsf_driver.py | 31 ++--- .../scheduler/test_openpbs_driver.py | 5 +- .../unit_tests/scheduler/test_scheduler.py | 7 +- .../unit_tests/scheduler/test_slurm_driver.py | 14 +- .../ert/unit_tests/storage/create_runpath.py | 8 +- .../unit_tests/storage/test_local_storage.py | 26 ++-- tests/ert/unit_tests/storage/test_mode.py | 7 +- .../ert/unit_tests/test_run_path_creation.py | 4 +- tests/ert/unit_tests/test_tracking.py | 7 +- .../workflow_runner/test_workflow.py | 4 +- .../workflow_runner/test_workflow_job.py | 2 +- tests/everest/conftest.py | 10 +- .../entry_points/test_config_branch_entry.py | 4 +- tests/everest/entry_points/test_everexport.py | 2 +- .../functional/test_main_everest_entry.py | 9 +- tests/everest/test_api_snapshots.py | 4 +- tests/everest/test_config_validation.py | 6 +- tests/everest/test_controls.py | 3 +- tests/everest/test_egg_simulation.py | 4 +- tests/everest/test_everest_config.py | 3 +- tests/everest/test_fm_plugins.py | 12 +- tests/everest/test_logging.py | 5 +- tests/everest/test_repo_configs.py | 2 +- tests/everest/test_templating.py | 2 +- tests/everest/test_workflows.py | 4 +- tests/everest/utils/__init__.py | 4 +- .../utils/test_pydantic_doc_generation.py | 2 +- 369 files changed, 1965 insertions(+), 2283 deletions(-) diff --git a/docs/ert/conf.py b/docs/ert/conf.py index fa79b90b611..71a069bd046 100644 --- a/docs/ert/conf.py +++ b/docs/ert/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Configuration file for the Sphinx documentation builder. # diff --git a/docs/everest/conf.py b/docs/everest/conf.py index 2f42f0d3c1a..fe911e8c713 100644 --- a/docs/everest/conf.py +++ b/docs/everest/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Configuration file for the Sphinx documentation builder. # diff --git a/pyproject.toml b/pyproject.toml index e353f4ceb9c..00295d9142d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -194,6 +194,7 @@ select = [ "C4", # flake8-comprehensions "ASYNC", # flake8-async "RUF", # ruff specific rules + "UP", # pyupgrade ] preview = true ignore = [ @@ -210,6 +211,8 @@ ignore = [ "PLR0904", # too-many-public-methods "PLR1702", # too-many-nested-blocks "PLW3201", # bad-dunder-method-name + "UP032", # f-string + "UP031", # printf-string-formatting ] # Allow EN DASH (U+2013) @@ -221,6 +224,7 @@ allowed-confusables = ["–"] "RUF029", # unused-async "RUF018", # assignment-in-assert "RUF006", # asyncio-dangling-task + "PLW1508", # Invalid type of environment variable default ] "src/ert/dark_storage/json_schema/__init__.py" = ["F401"] "src/ert/dark_storage/*" = ["RUF029"] # unused-async diff --git a/src/_ert/async_utils.py b/src/_ert/async_utils.py index c8e18cc6800..efd49b958b8 100644 --- a/src/_ert/async_utils.py +++ b/src/_ert/async_utils.py @@ -3,8 +3,9 @@ import asyncio import logging import traceback +from collections.abc import Coroutine, Generator from contextlib import suppress -from typing import Any, Coroutine, Generator, TypeVar, Union +from typing import Any, TypeVar logger = logging.getLogger(__name__) @@ -29,7 +30,7 @@ def get_running_loop() -> asyncio.AbstractEventLoop: def _create_task( loop: asyncio.AbstractEventLoop, - coro: Union[Coroutine[Any, Any, _T], Generator[Any, None, _T]], + coro: Coroutine[Any, Any, _T] | Generator[Any, None, _T], ) -> asyncio.Task[_T]: assert asyncio.iscoroutine(coro) task = asyncio.Task(coro, loop=loop) @@ -47,9 +48,7 @@ def _done_callback(task: asyncio.Task[_T_co]) -> None: traceback.format_exception(None, exc, exc.__traceback__) ) logger.error( - ( - f"Exception in scheduler task {task.get_name()}: {exc}\n" - f"Traceback: {exc_traceback}" - ) + f"Exception in scheduler task {task.get_name()}: {exc}\n" + f"Traceback: {exc_traceback}" ) raise exc diff --git a/src/_ert/forward_model_runner/cli.py b/src/_ert/forward_model_runner/cli.py index 0398ca68b73..6f99ef576c9 100644 --- a/src/_ert/forward_model_runner/cli.py +++ b/src/_ert/forward_model_runner/cli.py @@ -24,8 +24,8 @@ def _setup_reporters( ee_token=None, ee_cert_path=None, experiment_id=None, -) -> typing.List[reporting.Reporter]: - reporters: typing.List[reporting.Reporter] = [] +) -> list[reporting.Reporter]: + reporters: list[reporting.Reporter] = [] if is_interactive_run: reporters.append(reporting.Interactive()) elif ens_id and experiment_id is None: @@ -77,10 +77,10 @@ def _wait_for_retry(): def _read_jobs_file(retry=True): try: - with open(JOBS_FILE, "r", encoding="utf-8") as json_file: + with open(JOBS_FILE, encoding="utf-8") as json_file: return json.load(json_file) except json.JSONDecodeError as e: - raise IOError("Job Runner cli failed to load JSON-file.") from e + raise OSError("Job Runner cli failed to load JSON-file.") from e except FileNotFoundError as e: if retry: logger.error(f"Could not find file {JOBS_FILE}, retrying") diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index e04e3be6ebd..ea798522b86 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -1,7 +1,7 @@ import asyncio import logging import ssl -from typing import Any, AnyStr, Optional, Self, Union +from typing import Any, AnyStr, Self from websockets.asyncio.client import ClientConnection, connect from websockets.datastructures import Headers @@ -50,10 +50,10 @@ async def __aexit__( def __init__( self, url: str, - token: Optional[str] = None, - cert: Optional[Union[str, bytes]] = None, - max_retries: Optional[int] = None, - timeout_multiplier: Optional[int] = None, + token: str | None = None, + cert: str | bytes | None = None, + max_retries: int | None = None, + timeout_multiplier: int | None = None, ) -> None: if max_retries is None: max_retries = self.DEFAULT_MAX_RETRIES @@ -72,7 +72,7 @@ def __init__( # if True it will enforce TLS, and if you want to use self signed # certificates you need to pass an ssl_context with the certificate # loaded. - self._ssl_context: Optional[Union[bool, ssl.SSLContext]] = None + self._ssl_context: bool | ssl.SSLContext | None = None if cert is not None: self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) self._ssl_context.load_verify_locations(cadata=cert) @@ -81,7 +81,7 @@ def __init__( self._max_retries = max_retries self._timeout_multiplier = timeout_multiplier - self.websocket: Optional[ClientConnection] = None + self.websocket: ClientConnection | None = None self.loop = new_event_loop() async def get_websocket(self) -> ClientConnection: @@ -103,33 +103,28 @@ async def _send(self, msg: AnyStr) -> None: await self.websocket.send(msg) return except ConnectionClosedOK as exception: - _error_msg = ( + error_msg = ( f"Connection closed received from the server {self.url}! " f" Exception from {type(exception)}: {exception!s}" ) - raise ClientConnectionClosedOK(_error_msg) from exception - except ( - InvalidHandshake, - InvalidURI, - OSError, - asyncio.TimeoutError, - ) as exception: + raise ClientConnectionClosedOK(error_msg) from exception + except (TimeoutError, InvalidHandshake, InvalidURI, OSError) as exception: if retry == self._max_retries: - _error_msg = ( + error_msg = ( f"Not able to establish the " f"websocket connection {self.url}! Max retries reached!" " Check for firewall issues." f" Exception from {type(exception)}: {exception!s}" ) - raise ClientConnectionError(_error_msg) from exception + raise ClientConnectionError(error_msg) from exception except ConnectionClosedError as exception: if retry == self._max_retries: - _error_msg = ( + error_msg = ( f"Not been able to send the event" f" to {self.url}! Max retries reached!" f" Exception from {type(exception)}: {exception!s}" ) - raise ClientConnectionError(_error_msg) from exception + raise ClientConnectionError(error_msg) from exception await asyncio.sleep(0.2 + self._timeout_multiplier * retry) self.websocket = None diff --git a/src/_ert/forward_model_runner/forward_model_step.py b/src/_ert/forward_model_runner/forward_model_step.py index adf5581f8f1..e15a45c537c 100644 --- a/src/_ert/forward_model_runner/forward_model_step.py +++ b/src/_ert/forward_model_runner/forward_model_step.py @@ -9,10 +9,11 @@ import socket import sys import time +from collections.abc import Generator, Sequence from datetime import datetime as dt from pathlib import Path from subprocess import Popen, run -from typing import TYPE_CHECKING, Generator, Sequence, cast +from typing import TYPE_CHECKING, cast from psutil import AccessDenied, NoSuchProcess, Process, TimeoutExpired, ZombieProcess @@ -91,8 +92,7 @@ def __init__( def run(self) -> Generator[Start | Exited | Running | None]: try: - for msg in self._run(): - yield msg + yield from self._run() except Exception as e: yield Exited(self, exit_code=1).with_error(str(e)) @@ -293,11 +293,9 @@ def handle_process_timeout_and_create_exited_msg( os.killpg(process_group_id, signal.SIGKILL) return Exited(self, exit_code).with_error( - ( - f"Job:{self.name()} has been running " - f"for more than {max_running_minutes} " - "minutes - explicitly killed." - ) + f"Job:{self.name()} has been running " + f"for more than {max_running_minutes} " + "minutes - explicitly killed." ) def _handle_process_io_error_and_create_exited_message( @@ -403,10 +401,8 @@ def _assert_arg_list(self): int(arg_list[index]) except ValueError: errors.append( - ( - f"In job {self.name()}: argument with index {index} " - "is of incorrect type, should be integer." - ) + f"In job {self.name()}: argument with index {index} " + "is of incorrect type, should be integer." ) return errors diff --git a/src/_ert/forward_model_runner/reporting/event.py b/src/_ert/forward_model_runner/reporting/event.py index f4f140232e1..81cbb43e682 100644 --- a/src/_ert/forward_model_runner/reporting/event.py +++ b/src/_ert/forward_model_runner/reporting/event.py @@ -5,7 +5,7 @@ import threading from datetime import datetime, timedelta from pathlib import Path -from typing import Final, Union +from typing import Final from _ert import events from _ert.events import ( @@ -138,7 +138,7 @@ def _init_handler(self, msg: Init): self._real_id = str(msg.real_id) self._event_publisher_thread.start() - def _job_handler(self, msg: Union[Start, Running, Exited]): + def _job_handler(self, msg: Start | Running | Exited): assert msg.job job_name = msg.job.name() job_msg = { diff --git a/src/_ert/forward_model_runner/reporting/file.py b/src/_ert/forward_model_runner/reporting/file.py index e6e601fe0f2..9a62bc5871a 100644 --- a/src/_ert/forward_model_runner/reporting/file.py +++ b/src/_ert/forward_model_runner/reporting/file.py @@ -179,9 +179,7 @@ def _dump_error_file(fm_step, error_msg): stderr_file = None if fm_step.std_err: if os.path.exists(fm_step.std_err): - with open( - fm_step.std_err, "r", encoding="utf-8" - ) as error_file_handler: + with open(fm_step.std_err, encoding="utf-8") as error_file_handler: stderr = error_file_handler.read() if stderr: stderr_file = os.path.join(os.getcwd(), fm_step.std_err) diff --git a/src/_ert/forward_model_runner/reporting/interactive.py b/src/_ert/forward_model_runner/reporting/interactive.py index fd489c78378..b508d5037f8 100644 --- a/src/_ert/forward_model_runner/reporting/interactive.py +++ b/src/_ert/forward_model_runner/reporting/interactive.py @@ -1,5 +1,3 @@ -from typing import Optional - from _ert.forward_model_runner.reporting.base import Reporter from _ert.forward_model_runner.reporting.message import ( _JOB_EXIT_FAILED_STRING, @@ -11,8 +9,8 @@ class Interactive(Reporter): @staticmethod - def _report(msg: Message) -> Optional[str]: - if not isinstance(msg, (Start, Finish)): + def _report(msg: Message) -> str | None: + if not isinstance(msg, Start | Finish): return None if isinstance(msg, Finish): return ( @@ -27,6 +25,6 @@ def _report(msg: Message) -> Optional[str]: return f"Running job: {msg.job.name()} ... " def report(self, msg: Message): - _msg = self._report(msg) - if _msg is not None: - print(_msg) + msg_ = self._report(msg) + if msg_ is not None: + print(msg_) diff --git a/src/_ert/forward_model_runner/reporting/message.py b/src/_ert/forward_model_runner/reporting/message.py index d07816389f2..a304fede556 100644 --- a/src/_ert/forward_model_runner/reporting/message.py +++ b/src/_ert/forward_model_runner/reporting/message.py @@ -1,6 +1,6 @@ import dataclasses from datetime import datetime as dt -from typing import TYPE_CHECKING, Dict, Literal, Optional +from typing import TYPE_CHECKING, Literal import psutil from typing_extensions import TypedDict @@ -39,17 +39,17 @@ class ProcessTreeStatus: """Holds processtree information that can be represented as a line of CSV data""" timestamp: str = "" - fm_step_id: Optional[int] = None - fm_step_name: Optional[str] = None + fm_step_id: int | None = None + fm_step_name: str | None = None # Memory unit is bytes - rss: Optional[int] = None - max_rss: Optional[int] = None - free: Optional[int] = None + rss: int | None = None + max_rss: int | None = None + free: int | None = None cpu_seconds: float = 0.0 - oom_score: Optional[int] = None + oom_score: int | None = None def __post_init__(self): self.timestamp = dt.now().isoformat() @@ -72,8 +72,8 @@ def __repr__(cls): class Message(metaclass=_MetaMessage): def __init__(self, job=None): self.timestamp = dt.now() - self.job: Optional[ForwardModelStep] = job - self.error_message: Optional[str] = None + self.job: ForwardModelStep | None = job + self.error_message: str | None = None def __repr__(self): return type(self).__name__ @@ -134,7 +134,7 @@ def __init__(self, fm_step, exit_code: int): class Checksum(Message): - def __init__(self, checksum_dict: Dict[str, "ChecksumDict"], run_path: str): + def __init__(self, checksum_dict: dict[str, "ChecksumDict"], run_path: str): super().__init__() self.data = checksum_dict self.run_path = run_path diff --git a/src/_ert/forward_model_runner/reporting/statemachine.py b/src/_ert/forward_model_runner/reporting/statemachine.py index 4d749414e4d..97672cd5b38 100644 --- a/src/_ert/forward_model_runner/reporting/statemachine.py +++ b/src/_ert/forward_model_runner/reporting/statemachine.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Dict, Tuple, Type +from collections.abc import Callable from _ert.forward_model_runner.reporting.message import ( Checksum, @@ -25,7 +25,7 @@ def __init__(self) -> None: jobs = (Start, Running, Exited) checksum = (Checksum,) finished = (Finish,) - self._handler: Dict[Message, Callable[[Message], None]] = {} + self._handler: dict[Message, Callable[[Message], None]] = {} self._transitions = { None: initialized, initialized: jobs + checksum + finished, @@ -35,7 +35,7 @@ def __init__(self) -> None: self._state = None def add_handler( - self, states: Tuple[Type[Message], ...], handler: Callable[[Message], None] + self, states: tuple[type[Message], ...], handler: Callable[[Message], None] ) -> None: if states in self._handler: raise ValueError(f"{states} already handled by {self._handler[states]}") diff --git a/src/_ert/forward_model_runner/runner.py b/src/_ert/forward_model_runner/runner.py index bd304f3c7d3..f1186977b9a 100644 --- a/src/_ert/forward_model_runner/runner.py +++ b/src/_ert/forward_model_runner/runner.py @@ -2,14 +2,14 @@ import json import os from pathlib import Path -from typing import Any, Dict, List +from typing import Any from _ert.forward_model_runner.forward_model_step import ForwardModelStep from _ert.forward_model_runner.reporting.message import Checksum, Finish, Init class ForwardModelRunner: - def __init__(self, steps_data: Dict[str, Any]): + def __init__(self, steps_data: dict[str, Any]): self.steps_data = ( steps_data # On disk, this is called jobs.json for legacy reasons ) @@ -22,7 +22,7 @@ def __init__(self, steps_data: Dict[str, Any]): if self.simulation_id is not None: os.environ["ERT_RUN_ID"] = self.simulation_id - self.steps: List[ForwardModelStep] = [] + self.steps: list[ForwardModelStep] = [] for index, step_data in enumerate(steps_data["jobList"]): self.steps.append(ForwardModelStep(step_data, index)) @@ -31,7 +31,7 @@ def __init__(self, steps_data: Dict[str, Any]): def _read_manifest(self): if not Path("manifest.json").exists(): return None - with open("manifest.json", mode="r", encoding="utf-8") as f: + with open("manifest.json", encoding="utf-8") as f: data = json.load(f) return { name: {"type": "file", "path": str(Path(file).absolute())} @@ -49,7 +49,7 @@ def _populate_checksums(self, manifest): info["error"] = f"Expected file {path} not created by forward model!" return manifest - def run(self, names_of_steps_to_run: List[str]): + def run(self, names_of_steps_to_run: list[str]): if not names_of_steps_to_run: step_queue = self.steps else: diff --git a/src/_ert/threading.py b/src/_ert/threading.py index 69c50b7443c..c759c9cbcc6 100644 --- a/src/_ert/threading.py +++ b/src/_ert/threading.py @@ -5,14 +5,15 @@ import signal import threading import traceback +from collections.abc import Callable, Iterable from threading import Thread as _Thread from types import FrameType -from typing import Any, Callable, Iterable, Optional +from typing import Any logger = logging.getLogger(__name__) -_current_exception: Optional[ErtThreadError] = None +_current_exception: ErtThreadError | None = None _can_raise = False diff --git a/src/ert/__main__.py b/src/ert/__main__.py index 45c892c1cbe..11808d0ed2c 100755 --- a/src/ert/__main__.py +++ b/src/ert/__main__.py @@ -11,7 +11,8 @@ import sys import warnings from argparse import ArgumentParser, ArgumentTypeError -from typing import Any, Dict, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any from uuid import UUID import yaml @@ -50,14 +51,14 @@ logger = logging.getLogger(__name__) -def run_ert_storage(args: Namespace, _: Optional[ErtPluginManager] = None) -> None: +def run_ert_storage(args: Namespace, _: ErtPluginManager | None = None) -> None: with StorageService.start_server( verbose=True, project=ErtConfig.from_file(args.config).ens_path ) as server: server.wait() -def run_webviz_ert(args: Namespace, _: Optional[ErtPluginManager] = None) -> None: +def run_webviz_ert(args: Namespace, _: ErtPluginManager | None = None) -> None: try: import webviz_ert # type: ignore # noqa except ImportError as err: @@ -65,7 +66,7 @@ def run_webviz_ert(args: Namespace, _: Optional[ErtPluginManager] = None) -> Non "Running `ert vis` requires that webviz_ert is installed" ) from err - kwargs: Dict[str, Any] = {"verbose": args.verbose} + kwargs: dict[str, Any] = {"verbose": args.verbose} ert_config = ErtConfig.with_plugins().from_file(args.config) os.chdir(ert_config.config_path) ens_path = ert_config.ens_path @@ -141,7 +142,7 @@ def valid_name(user_input: str) -> str: return user_input -def valid_ensemble(user_input: str) -> Union[str, UUID]: +def valid_ensemble(user_input: str) -> str | UUID: if user_input.startswith("UUID="): return UUID(user_input[5:]) return valid_name(user_input) @@ -194,8 +195,8 @@ def run_lint_wrapper(args: Namespace, _: ErtPluginManager) -> None: class DeprecatedAction(argparse.Action): - def __init__(self, alternative_option: Optional[str] = None, **kwargs: Any) -> None: - self.alternative_option: Optional[str] = alternative_option + def __init__(self, alternative_option: str | None = None, **kwargs: Any) -> None: + self.alternative_option: str | None = alternative_option super().__init__(**kwargs) def __call__( @@ -203,7 +204,7 @@ def __call__( parser: ArgumentParser, namespace: argparse.Namespace, values: Any, - option_string: Optional[str] = None, + option_string: str | None = None, ) -> None: alternative_msg: str = ( f"Use {self.alternative_option} instead." if self.alternative_option else "" @@ -215,7 +216,7 @@ def __call__( setattr(namespace, self.dest, values) -def get_ert_parser(parser: Optional[ArgumentParser] = None) -> ArgumentParser: +def get_ert_parser(parser: ArgumentParser | None = None) -> ArgumentParser: if parser is None: parser = ArgumentParser(description="ERT - Ensemble Reservoir Tool") @@ -605,7 +606,7 @@ def get_ert_parser(parser: Optional[ArgumentParser] = None) -> ArgumentParser: return parser -def ert_parser(parser: Optional[ArgumentParser], args: Sequence[str]) -> Namespace: +def ert_parser(parser: ArgumentParser | None, args: Sequence[str]) -> Namespace: return get_ert_parser(parser).parse_args( args, namespace=Namespace(), @@ -617,7 +618,7 @@ def log_process_usage() -> None: usage = resource.getrusage(resource.RUSAGE_SELF) max_rss = ert.shared.status.utils.get_ert_memory_usage() - usage_dict: Dict[str, Union[int, float]] = { + usage_dict: dict[str, int | float] = { "User time": usage.ru_utime, "System time": usage.ru_stime, "File system inputs": usage.ru_inblock, diff --git a/src/ert/analysis/_es_update.py b/src/ert/analysis/_es_update.py index bac5c2042b8..5c96b7c4051 100644 --- a/src/ert/analysis/_es_update.py +++ b/src/ert/analysis/_es_update.py @@ -3,15 +3,12 @@ import functools import logging import time +from collections.abc import Callable, Iterable, Sequence from fnmatch import fnmatch from typing import ( TYPE_CHECKING, - Callable, Generic, - Iterable, - Optional, Self, - Sequence, TypeVar, ) @@ -165,7 +162,7 @@ def _load_observations_and_responses( global_std_scaling: float, iens_active_index: npt.NDArray[np.int_], selected_observations: Iterable[str], - auto_scale_observations: Optional[list[ObservationGroups]], + auto_scale_observations: list[ObservationGroups] | None, progress_callback: Callable[[AnalysisEvent], None], ) -> tuple[ npt.NDArray[np.float64], @@ -262,7 +259,7 @@ def _load_observations_and_responses( ) ) - if len(scaling_factors_dfs): + if scaling_factors_dfs: scaling_factors_df = polars.concat(scaling_factors_dfs) ensemble.save_observation_scaling_factors(scaling_factors_df) @@ -445,7 +442,7 @@ def analysis_ES( source_ensemble: Ensemble, target_ensemble: Ensemble, progress_callback: Callable[[AnalysisEvent], None], - auto_scale_observations: Optional[list[ObservationGroups]], + auto_scale_observations: list[ObservationGroups] | None, ) -> None: iens_active_index = np.flatnonzero(ens_mask) @@ -574,12 +571,12 @@ def correlation_callback( t["name"] # type: ignore for t in config_node.transform_function_definitions ] - _cross_correlations = np.vstack(cross_correlations) - if _cross_correlations.size != 0: + cross_correlations_ = np.vstack(cross_correlations) + if cross_correlations_.size != 0: source_ensemble.save_cross_correlations( - _cross_correlations, + cross_correlations_, param_group, - parameter_names[: _cross_correlations.shape[0]], + parameter_names[: cross_correlations_.shape[0]], ) logger.info( f"Adaptive Localization of {param_group} completed in {(time.time() - start) / 60} minutes" @@ -623,7 +620,7 @@ def analysis_IES( ens_mask: npt.NDArray[np.bool_], source_ensemble: Ensemble, target_ensemble: Ensemble, - sies_smoother: Optional[ies.SIES], + sies_smoother: ies.SIES | None, progress_callback: Callable[[AnalysisEvent], None], auto_scale_observations: list[ObservationGroups], sies_step_length: Callable[[int], float], @@ -750,10 +747,10 @@ def smoother_update( posterior_storage: Ensemble, observations: Iterable[str], parameters: Iterable[str], - analysis_config: Optional[UpdateSettings] = None, - es_settings: Optional[ESSettings] = None, - rng: Optional[np.random.Generator] = None, - progress_callback: Optional[Callable[[AnalysisEvent], None]] = None, + analysis_config: UpdateSettings | None = None, + es_settings: ESSettings | None = None, + rng: np.random.Generator | None = None, + progress_callback: Callable[[AnalysisEvent], None] | None = None, global_scaling: float = 1.0, ) -> SmootherSnapshot: if not progress_callback: @@ -814,15 +811,15 @@ def smoother_update( def iterative_smoother_update( prior_storage: Ensemble, posterior_storage: Ensemble, - sies_smoother: Optional[ies.SIES], + sies_smoother: ies.SIES | None, parameters: Iterable[str], observations: Iterable[str], update_settings: UpdateSettings, analysis_config: IESSettings, sies_step_length: Callable[[int], float], initial_mask: npt.NDArray[np.bool_], - rng: Optional[np.random.Generator] = None, - progress_callback: Optional[Callable[[AnalysisEvent], None]] = None, + rng: np.random.Generator | None = None, + progress_callback: Callable[[AnalysisEvent], None] | None = None, global_scaling: float = 1.0, ) -> tuple[SmootherSnapshot, ies.SIES]: if not progress_callback: diff --git a/src/ert/analysis/misfit_preprocessor.py b/src/ert/analysis/misfit_preprocessor.py index 82967b72c84..693d211652c 100644 --- a/src/ert/analysis/misfit_preprocessor.py +++ b/src/ert/analysis/misfit_preprocessor.py @@ -1,5 +1,4 @@ import logging -from typing import Tuple import numpy as np import numpy.typing as npt @@ -19,19 +18,15 @@ def get_scaling_factor(nr_observations: int, nr_components: int) -> float: below a user threshold """ logger.info( - ( - f"Calculation scaling factor, nr of primary components: " - f"{nr_components}, number of observations: {nr_observations}" - ) + f"Calculation scaling factor, nr of primary components: " + f"{nr_components}, number of observations: {nr_observations}" ) if nr_components == 0: nr_components = 1 logger.warning( - ( - "Number of PCA components is 0. " - "Setting to 1 to avoid division by zero " - "when calculating scaling factor" - ) + "Number of PCA components is 0. " + "Setting to 1 to avoid division by zero " + "when calculating scaling factor" ) return np.sqrt(nr_observations / float(nr_components)) @@ -72,7 +67,7 @@ def cluster_responses( def main( responses: npt.NDArray[np.float64], obs_errors: npt.NDArray[np.float64], -) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int_], npt.NDArray[np.int_]]: +) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.int_], npt.NDArray[np.int_]]: """ Perform 'Auto Scaling' to mitigate issues with correlated observations in ensemble smoothers. diff --git a/src/ert/analysis/snapshots.py b/src/ert/analysis/snapshots.py index 87b6ae213fe..505b42ab2ef 100644 --- a/src/ert/analysis/snapshots.py +++ b/src/ert/analysis/snapshots.py @@ -1,5 +1,5 @@ from enum import Enum, auto -from typing import Any, Dict, List +from typing import Any import numpy as np from pydantic import BaseModel @@ -53,10 +53,10 @@ class SmootherSnapshot(BaseModel): alpha: float std_cutoff: float global_scaling: float - update_step_snapshots: List[ObservationAndResponseSnapshot] + update_step_snapshots: list[ObservationAndResponseSnapshot] @property - def header(self) -> List[str]: + def header(self) -> list[str]: return [ "Observation name", "Index", @@ -70,7 +70,7 @@ def header(self) -> List[str]: ] @property - def csv(self) -> List[List[Any]]: + def csv(self) -> list[list[Any]]: data = [] for step in self.update_step_snapshots: data.append( @@ -89,7 +89,7 @@ def csv(self) -> List[List[Any]]: return data @property - def extra(self) -> Dict[str, str]: + def extra(self) -> dict[str, str]: return { "Parent ensemble": self.source_ensemble_name, "Target ensemble": self.target_ensemble_name, diff --git a/src/ert/cli/main.py b/src/ert/cli/main.py index c0069bd1fc9..10f325de1a3 100644 --- a/src/ert/cli/main.py +++ b/src/ert/cli/main.py @@ -7,7 +7,7 @@ import queue import sys from collections import Counter -from typing import Optional, TextIO +from typing import TextIO from _ert.threading import ErtThread from ert.cli.monitor import Monitor @@ -34,7 +34,7 @@ class ErtCliError(Exception): pass -def run_cli(args: Namespace, plugin_manager: Optional[ErtPluginManager] = None) -> None: +def run_cli(args: Namespace, plugin_manager: ErtPluginManager | None = None) -> None: ert_dir = os.path.abspath(os.path.dirname(args.config)) os.chdir(ert_dir) # Changing current working directory means we need to update @@ -133,7 +133,7 @@ def run_cli(args: Namespace, plugin_manager: Optional[ErtPluginManager] = None) out = sys.stderr monitor = Monitor(out=out, color_always=args.color_always) thread.start() - end_event: Optional[EndEvent] = None + end_event: EndEvent | None = None try: end_event = monitor.monitor( status_queue, ert_config.analysis_config.log_path diff --git a/src/ert/cli/monitor.py b/src/ert/cli/monitor.py index 037ffcc3334..89831e1079b 100644 --- a/src/ert/cli/monitor.py +++ b/src/ert/cli/monitor.py @@ -1,11 +1,10 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import sys from datetime import datetime, timedelta from pathlib import Path from queue import SimpleQueue -from typing import Dict, Optional, TextIO, Tuple +from typing import TextIO from tqdm import tqdm @@ -31,7 +30,7 @@ ) from ert.shared.status.utils import format_running_time -Color = Tuple[int, int, int] +Color = tuple[int, int, int] def _no_color(text: str, color: Color) -> str: @@ -59,8 +58,8 @@ class Monitor: def __init__(self, out: TextIO = sys.stdout, color_always: bool = False) -> None: self._out = out - self._snapshots: Dict[int, EnsembleSnapshot] = {} - self._start_time: Optional[datetime] = None + self._snapshots: dict[int, EnsembleSnapshot] = {} + self._start_time: datetime | None = None self._colorize = ansi_color # If out is not (like) a tty, disable colors. if not out.isatty() and not color_always: @@ -73,7 +72,7 @@ def __init__(self, out: TextIO = sys.stdout, color_always: bool = False) -> None def monitor( self, event_queue: SimpleQueue[StatusEvents], - output_path: Optional[Path] = None, + output_path: Path | None = None, ) -> EndEvent: self._start_time = datetime.now() while True: @@ -100,7 +99,7 @@ def monitor( event.write_as_csv(output_path) def _print_job_errors(self) -> None: - failed_jobs: Dict[Optional[str], int] = {} + failed_jobs: dict[str | None, int] = {} for snapshot in self._snapshots.values(): for real in snapshot.reals.values(): for job in real["fm_steps"].values(): @@ -118,15 +117,15 @@ def _get_legends(self) -> str: aggregate = latest_snapshot.aggregate_real_states() for state_ in ALL_REALIZATION_STATES: count = aggregate[state_] - _countstring = f"{count}/{total_count}" + countstring = f"{count}/{total_count}" out = ( f"{self._colorize(self.dot, color=REAL_STATE_TO_COLOR[state_])}" - f"{state_:10} {_countstring:>10}" + f"{state_:10} {countstring:>10}" ) statuses += f" {out}\n" return statuses - def _print_result(self, failed: bool, failed_message: Optional[str]) -> None: + def _print_result(self, failed: bool, failed_message: str | None) -> None: if failed: msg = f"Experiment failed with the following error: {failed_message}" print(self._colorize(msg, color=COLOR_FAILED), file=self._out) diff --git a/src/ert/config/_get_num_cpu.py b/src/ert/config/_get_num_cpu.py index 524170597b1..4a6e466db1d 100644 --- a/src/ert/config/_get_num_cpu.py +++ b/src/ert/config/_get_num_cpu.py @@ -1,11 +1,12 @@ from __future__ import annotations -from typing import Iterator, Optional, TypeVar, overload +from collections.abc import Iterator +from typing import TypeVar, overload from .parsing import ConfigWarning -def get_num_cpu_from_data_file(data_file: str) -> Optional[int]: +def get_num_cpu_from_data_file(data_file: str) -> int | None: """Reads the number of cpus required from the reservoir simulator .data file. Works similarly to resdata.util.get_num_cpu @@ -76,7 +77,7 @@ def get_num_cpu_from_data_file(data_file: str) -> Optional[int]: """ try: - with open(data_file, "r", encoding="utf-8") as file: + with open(data_file, encoding="utf-8") as file: return _get_num_cpu(iter(file), data_file) except (OSError, UnicodeDecodeError) as err: ConfigWarning.warn( @@ -86,8 +87,8 @@ def get_num_cpu_from_data_file(data_file: str) -> Optional[int]: def _get_num_cpu( - lines_iter: Iterator[str], data_file_name: Optional[str] = None -) -> Optional[int]: + lines_iter: Iterator[str], data_file_name: str | None = None +) -> int | None: """Handles reading the lines in the data file and returns the num_cpu TITLE keyword requires skipping one line diff --git a/src/ert/config/_option_dict.py b/src/ert/config/_option_dict.py index ee8287e14ce..75ec081b8a0 100644 --- a/src/ert/config/_option_dict.py +++ b/src/ert/config/_option_dict.py @@ -1,12 +1,12 @@ import logging -from typing import Dict, Sequence +from collections.abc import Sequence from .parsing import ConfigValidationError logger = logging.getLogger(__name__) -def option_dict(option_list: Sequence[str], offset: int) -> Dict[str, str]: +def option_dict(option_list: Sequence[str], offset: int) -> dict[str, str]: """Gets the list of options given to a keywords such as GEN_DATA. The first step of parsing will separate a line such as diff --git a/src/ert/config/_read_summary.py b/src/ert/config/_read_summary.py index eb5a29c00da..6a444e0d9a4 100644 --- a/src/ert/config/_read_summary.py +++ b/src/ert/config/_read_summary.py @@ -4,18 +4,12 @@ import os import os.path import re +from collections.abc import Callable, Sequence from datetime import datetime, timedelta from enum import Enum, auto from typing import ( Any, - Callable, - Dict, - List, - Optional, - Sequence, - Tuple, TypeVar, - Union, ) import numpy as np @@ -30,7 +24,7 @@ def _cell_index( array_index: int, nx: PositiveInt, ny: PositiveInt -) -> Tuple[int, int, int]: +) -> tuple[int, int, int]: k = array_index // (nx * ny) array_index -= k * (nx * ny) j = array_index // nx @@ -43,8 +37,8 @@ def _cell_index( def _check_if_missing( - keyword_name: str, missing_key: str, *test_vars: Optional[T] -) -> List[T]: + keyword_name: str, missing_key: str, *test_vars: T | None +) -> list[T]: if any(v is None for v in test_vars): raise InvalidResponseFile( f"Found {keyword_name} keyword in summary " @@ -55,14 +49,14 @@ def _check_if_missing( def make_summary_key( keyword: str, - number: Optional[int] = None, - name: Optional[str] = None, - nx: Optional[int] = None, - ny: Optional[int] = None, - lgr_name: Optional[str] = None, - li: Optional[int] = None, - lj: Optional[int] = None, - lk: Optional[int] = None, + number: int | None = None, + name: str | None = None, + nx: int | None = None, + ny: int | None = None, + lgr_name: str | None = None, + li: int | None = None, + lj: int | None = None, + lk: int | None = None, ) -> str: try: sum_type = SummaryKeyType.from_keyword(keyword) @@ -125,7 +119,7 @@ def make_delta(self, val: float) -> timedelta: raise InvalidResponseFile(f"Unknown date unit {val}") -def _is_base_with_extension(base: str, path: str, exts: List[str]) -> bool: +def _is_base_with_extension(base: str, path: str, exts: list[str]) -> bool: """ >>> _is_base_with_extension("ECLBASE", "ECLBASE.SMSPEC", ["smspec"]) True @@ -166,7 +160,7 @@ def _find_file_matching( return os.path.join(dir, candidates[0]) -def _get_summary_filenames(filepath: str) -> Tuple[str, str]: +def _get_summary_filenames(filepath: str) -> tuple[str, str]: if filepath.lower().endswith(".data"): # For backwards compatability, it is # allowed to give REFCASE and ECLBASE both @@ -179,7 +173,7 @@ def _get_summary_filenames(filepath: str) -> Tuple[str, str]: def read_summary( filepath: str, fetch_keys: Sequence[str] -) -> Tuple[datetime, List[str], Sequence[datetime], Any]: +) -> tuple[datetime, list[str], Sequence[datetime], Any]: summary, spec = _get_summary_filenames(filepath) try: date_index, start_date, date_units, keys, indices = _read_spec(spec, fetch_keys) @@ -193,14 +187,14 @@ def read_summary( return (start_date, keys, time_map, fetched) -def _key2str(key: Union[bytes, str]) -> str: +def _key2str(key: bytes | str) -> str: ret = key.decode() if isinstance(key, bytes) else key assert isinstance(ret, str) return ret.strip() def _check_vals( - kw: str, spec: str, vals: Union[npt.NDArray[Any], resfo.MESS] + kw: str, spec: str, vals: npt.NDArray[Any] | resfo.MESS ) -> npt.NDArray[Any]: if vals is resfo.MESS or isinstance(vals, resfo.MESS): raise InvalidResponseFile(f"{kw.strip()} in {spec} has incorrect type MESS") @@ -241,14 +235,14 @@ def _fetch_keys_to_matcher(fetch_keys: Sequence[str]) -> Callable[[str], bool]: def _read_spec( spec: str, fetch_keys: Sequence[str] -) -> Tuple[int, datetime, DateUnit, List[str], npt.NDArray[np.int64]]: +) -> tuple[int, datetime, DateUnit, list[str], npt.NDArray[np.int64]]: date = None n = None nx = None ny = None wgnames = None - arrays: Dict[str, Optional[npt.NDArray[Any]]] = dict.fromkeys( + arrays: dict[str, npt.NDArray[Any] | None] = dict.fromkeys( [ "NUMS ", "KEYWORDS", @@ -322,14 +316,14 @@ def _read_spec( if n is None: n = len(keywords) - indices: List[int] = [] - keys: List[str] = [] - index_mapping: Dict[str, int] = {} + indices: list[int] = [] + keys: list[str] = [] + index_mapping: dict[str, int] = {} date_index = None should_load_key = _fetch_keys_to_matcher(fetch_keys) - def optional_get(arr: Optional[npt.NDArray[Any]], idx: int) -> Any: + def optional_get(arr: npt.NDArray[Any] | None, idx: int) -> Any: if arr is None: return None if len(arr) <= idx: @@ -412,7 +406,7 @@ def _read_summary( unit: DateUnit, indices: npt.NDArray[np.int64], date_index: int, -) -> Tuple[npt.NDArray[np.float32], List[datetime]]: +) -> tuple[npt.NDArray[np.float32], list[datetime]]: if summary.lower().endswith("funsmry"): mode = "rt" format = resfo.Format.FORMATTED @@ -421,8 +415,8 @@ def _read_summary( format = resfo.Format.UNFORMATTED last_params = None - values: List[npt.NDArray[np.float32]] = [] - dates: List[datetime] = [] + values: list[npt.NDArray[np.float32]] = [] + dates: list[datetime] = [] def read_params() -> None: nonlocal last_params, values diff --git a/src/ert/config/analysis_config.py b/src/ert/config/analysis_config.py index f1009877f22..3ba2024a35e 100644 --- a/src/ert/config/analysis_config.py +++ b/src/ert/config/analysis_config.py @@ -5,7 +5,7 @@ from math import ceil from os.path import realpath from pathlib import Path -from typing import Any, Dict, Final, List, Optional, Union, no_type_check +from typing import Any, Final, no_type_check from pydantic import PositiveFloat, ValidationError @@ -22,30 +22,30 @@ logger = logging.getLogger(__name__) DEFAULT_ANALYSIS_MODE = AnalysisMode.ENSEMBLE_SMOOTHER -ObservationGroups = List[str] +ObservationGroups = list[str] @dataclass class UpdateSettings: std_cutoff: PositiveFloat = 1e-6 alpha: float = 3.0 - auto_scale_observations: List[ObservationGroups] = field(default_factory=list) + auto_scale_observations: list[ObservationGroups] = field(default_factory=list) @dataclass class AnalysisConfig: - max_runtime: Optional[int] = None + max_runtime: int | None = None minimum_required_realizations: int = 0 - update_log_path: Union[str, Path] = "update_log" + update_log_path: str | Path = "update_log" es_module: ESSettings = field(default_factory=ESSettings) ies_module: IESSettings = field(default_factory=IESSettings) observation_settings: UpdateSettings = field(default_factory=UpdateSettings) num_iterations: int = 1 - design_matrix: Optional[DesignMatrix] = None + design_matrix: DesignMatrix | None = None @no_type_check @classmethod - def from_dict(cls, config_dict: ConfigDict) -> "AnalysisConfig": + def from_dict(cls, config_dict: ConfigDict) -> AnalysisConfig: num_realization: int = config_dict.get(ConfigKeys.NUM_REALIZATIONS, 1) min_realization_str: str = config_dict.get(ConfigKeys.MIN_REALIZATIONS, "0") if "%" in min_realization_str: @@ -83,8 +83,8 @@ def from_dict(cls, config_dict: ConfigDict) -> "AnalysisConfig": design_matrix_config_list = config_dict.get(ConfigKeys.DESIGN_MATRIX, None) - options: Dict[str, Dict[str, Any]] = {"STD_ENKF": {}, "IES_ENKF": {}} - observation_settings: Dict[str, Any] = { + options: dict[str, dict[str, Any]] = {"STD_ENKF": {}, "IES_ENKF": {}} + observation_settings: dict[str, Any] = { "alpha": config_dict.get(ConfigKeys.ENKF_ALPHA, 3.0), "std_cutoff": config_dict.get(ConfigKeys.STD_CUTOFF, 1e-6), "auto_scale_observations": [], diff --git a/src/ert/config/capture_validation.py b/src/ert/config/capture_validation.py index b57bd1d1700..67b2e77166c 100644 --- a/src/ert/config/capture_validation.py +++ b/src/ert/config/capture_validation.py @@ -1,9 +1,10 @@ from __future__ import annotations import logging +from collections.abc import Iterator from contextlib import contextmanager from dataclasses import dataclass, field -from typing import Iterator, cast +from typing import cast from warnings import catch_warnings from .parsing import ConfigValidationError, ConfigWarning, ErrorInfo, WarningInfo diff --git a/src/ert/config/design_matrix.py b/src/ert/config/design_matrix.py index f866766e41c..d4b69ad808f 100644 --- a/src/ert/config/design_matrix.py +++ b/src/ert/config/design_matrix.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING import numpy as np import pandas as pd @@ -11,10 +11,7 @@ from ert.config.gen_kw_config import GenKwConfig, TransformFunctionDefinition from ._option_dict import option_dict -from .parsing import ( - ConfigValidationError, - ErrorInfo, -) +from .parsing import ConfigValidationError, ErrorInfo if TYPE_CHECKING: from ert.config import ( @@ -31,13 +28,13 @@ class DesignMatrix: default_sheet: str def __post_init__(self) -> None: - self.num_realizations: Optional[int] = None - self.active_realizations: Optional[List[bool]] = None - self.design_matrix_df: Optional[pd.DataFrame] = None - self.parameter_configuration: Optional[Dict[str, ParameterConfig]] = None + self.num_realizations: int | None = None + self.active_realizations: list[bool] | None = None + self.design_matrix_df: pd.DataFrame | None = None + self.parameter_configuration: dict[str, ParameterConfig] | None = None @classmethod - def from_config_list(cls, config_list: List[str]) -> "DesignMatrix": + def from_config_list(cls, config_list: list[str]) -> DesignMatrix: filename = Path(config_list[0]) options = option_dict(config_list, 1) design_sheet = options.get("DESIGN_SHEET") @@ -119,8 +116,8 @@ def read_design_matrix( ) design_matrix_df = design_matrix_df.assign(**defaults_to_use) - parameter_configuration: Dict[str, ParameterConfig] = {} - transform_function_definitions: List[TransformFunctionDefinition] = [] + parameter_configuration: dict[str, ParameterConfig] = {} + transform_function_definitions: list[TransformFunctionDefinition] = [] for parameter in design_matrix_df.columns: transform_function_definitions.append( TransformFunctionDefinition( @@ -150,12 +147,12 @@ def read_design_matrix( @staticmethod def _read_excel( - file_name: Union[Path, str], + file_name: Path | str, sheet_name: str, - usecols: Optional[List[int]] = None, - header: Optional[int] = 0, - skiprows: Optional[int] = None, - dtype: Optional[str] = None, + usecols: list[int] | None = None, + header: int | None = 0, + skiprows: int | None = None, + dtype: str | None = None, ) -> pd.DataFrame: """ Reads an Excel file into a DataFrame, with options to filter columns and rows, @@ -172,7 +169,7 @@ def _read_excel( return df.dropna(axis=1, how="all") @staticmethod - def _validate_design_matrix(design_matrix: pd.DataFrame) -> List[str]: + def _validate_design_matrix(design_matrix: pd.DataFrame) -> list[str]: """ Validate user inputted design matrix :raises: ValueError if design matrix contains empty headers or empty cells @@ -200,10 +197,10 @@ def _validate_design_matrix(design_matrix: pd.DataFrame) -> List[str]: @staticmethod def _read_defaultssheet( - xls_filename: Union[Path, str], + xls_filename: Path | str, defaults_sheetname: str, - existing_parameters: List[str], - ) -> Dict[str, Union[str, float]]: + existing_parameters: list[str], + ) -> dict[str, str | float]: """ Construct a dict of keys and values to be used as defaults from the first two columns in a spreadsheet. Only returns the keys that are @@ -240,7 +237,7 @@ def _read_defaultssheet( } -def convert_to_numeric(x: str) -> Union[str, float]: +def convert_to_numeric(x: str) -> str | float: try: return pd.to_numeric(x) except ValueError: diff --git a/src/ert/config/ert_config.py b/src/ert/config/ert_config.py index d53de98aec3..8495ff2f48c 100644 --- a/src/ert/config/ert_config.py +++ b/src/ert/config/ert_config.py @@ -4,6 +4,7 @@ import logging import os from collections import defaultdict +from collections.abc import Sequence from dataclasses import field from datetime import datetime from os import path @@ -11,15 +12,7 @@ from typing import ( Any, ClassVar, - DefaultDict, - Dict, - List, - Optional, Self, - Sequence, - Tuple, - Type, - Union, no_type_check, overload, ) @@ -90,15 +83,15 @@ def site_config_location() -> str: def create_forward_model_json( context: Substitutions, - forward_model_steps: List[ForwardModelStep], - run_id: Optional[str], + forward_model_steps: list[ForwardModelStep], + run_id: str | None, iens: int = 0, itr: int = 0, - user_config_file: Optional[str] = "", - env_vars: Optional[Dict[str, str]] = None, - env_pr_fm_step: Optional[Dict[str, Dict[str, Any]]] = None, + user_config_file: str | None = "", + env_vars: dict[str, str] | None = None, + env_pr_fm_step: dict[str, dict[str, Any]] | None = None, skip_pre_experiment_validation: bool = False, -) -> Dict[str, Any]: +) -> dict[str, Any]: if env_vars is None: env_vars = {} if env_pr_fm_step is None: @@ -171,7 +164,7 @@ def handle_default(fm_step: ForwardModelStep, arg: str) -> str: config_file = str(config_file_path.name) if config_file_path else "" job_list_errors = [] - job_list: List[ForwardModelStepJSON] = [] + job_list: list[ForwardModelStepJSON] = [] for idx, fm_step in enumerate(forward_model_steps): substituter = Substituter(fm_step) fm_step_json = { @@ -230,14 +223,14 @@ def handle_default(fm_step: ForwardModelStep, arg: str) -> str: def forward_model_data_to_json( substitutions: Substitutions, - forward_model_steps: List[ForwardModelStep], - env_vars: Dict[str, str], - env_pr_fm_step: Optional[Dict[str, Dict[str, Any]]] = None, - user_config_file: Optional[str] = "", - run_id: Optional[str] = None, + forward_model_steps: list[ForwardModelStep], + env_vars: dict[str, str], + env_pr_fm_step: dict[str, dict[str, Any]] | None = None, + user_config_file: str | None = "", + run_id: str | None = None, iens: int = 0, itr: int = 0, - context_env: Optional[Dict[str, str]] = None, + context_env: dict[str, str] | None = None, ): if context_env is None: context_env = {} @@ -259,39 +252,39 @@ def forward_model_data_to_json( class ErtConfig: DEFAULT_ENSPATH: ClassVar[str] = "storage" DEFAULT_RUNPATH_FILE: ClassVar[str] = ".ert_runpath_list" - PREINSTALLED_FORWARD_MODEL_STEPS: ClassVar[Dict[str, ForwardModelStep]] = {} - ENV_PR_FM_STEP: ClassVar[Dict[str, Dict[str, Any]]] = {} - ACTIVATE_SCRIPT: Optional[str] = None + PREINSTALLED_FORWARD_MODEL_STEPS: ClassVar[dict[str, ForwardModelStep]] = {} + ENV_PR_FM_STEP: ClassVar[dict[str, dict[str, Any]]] = {} + ACTIVATE_SCRIPT: str | None = None substitutions: Substitutions = field(default_factory=Substitutions) ensemble_config: EnsembleConfig = field(default_factory=EnsembleConfig) ens_path: str = DEFAULT_ENSPATH - env_vars: Dict[str, str] = field(default_factory=dict) - random_seed: Optional[int] = None + env_vars: dict[str, str] = field(default_factory=dict) + random_seed: int | None = None analysis_config: AnalysisConfig = field(default_factory=AnalysisConfig) queue_config: QueueConfig = field(default_factory=QueueConfig) - workflow_jobs: Dict[str, WorkflowJob] = field(default_factory=dict) - workflows: Dict[str, Workflow] = field(default_factory=dict) - hooked_workflows: DefaultDict[HookRuntime, List[Workflow]] = field( + workflow_jobs: dict[str, WorkflowJob] = field(default_factory=dict) + workflows: dict[str, Workflow] = field(default_factory=dict) + hooked_workflows: defaultdict[HookRuntime, list[Workflow]] = field( default_factory=lambda: defaultdict(list) ) runpath_file: Path = Path(DEFAULT_RUNPATH_FILE) - ert_templates: List[Tuple[str, str]] = field(default_factory=list) - installed_forward_model_steps: Dict[str, ForwardModelStep] = field( + ert_templates: list[tuple[str, str]] = field(default_factory=list) + installed_forward_model_steps: dict[str, ForwardModelStep] = field( default_factory=dict ) - forward_model_steps: List[ForwardModelStep] = field(default_factory=list) + forward_model_steps: list[ForwardModelStep] = field(default_factory=list) model_config: ModelConfig = field(default_factory=ModelConfig) user_config_file: str = "no_config" config_path: str = field(init=False) - observation_config: List[ - Tuple[str, Union[HistoryValues, SummaryValues, GenObsValues]] + observation_config: list[ + tuple[str, HistoryValues | SummaryValues | GenObsValues] ] = field(default_factory=list) enkf_obs: EnkfObs = field(default_factory=EnkfObs) @field_validator("substitutions", mode="before") @classmethod - def convert_to_substitutions(cls, v: Dict[str, str]) -> Substitutions: + def convert_to_substitutions(cls, v: dict[str, str]) -> Substitutions: if isinstance(v, Substitutions): return v return Substitutions(v) @@ -324,17 +317,17 @@ def __post_init__(self) -> None: if self.user_config_file else os.getcwd() ) - self.observations: Dict[str, polars.DataFrame] = self.enkf_obs.datasets + self.observations: dict[str, polars.DataFrame] = self.enkf_obs.datasets @staticmethod def with_plugins( - forward_model_step_classes: Optional[List[Type[ForwardModelStepPlugin]]] = None, - env_pr_fm_step: Optional[Dict[str, Dict[str, Any]]] = None, - ) -> Type["ErtConfig"]: + forward_model_step_classes: list[type[ForwardModelStepPlugin]] | None = None, + env_pr_fm_step: dict[str, dict[str, Any]] | None = None, + ) -> type["ErtConfig"]: if forward_model_step_classes is None: forward_model_step_classes = ErtPluginManager().forward_model_steps - preinstalled_fm_steps: Dict[str, ForwardModelStepPlugin] = {} + preinstalled_fm_steps: dict[str, ForwardModelStepPlugin] = {} for fm_step_subclass in forward_model_step_classes: fm_step = fm_step_subclass() preinstalled_fm_steps[fm_step.name] = fm_step @@ -346,9 +339,9 @@ def with_plugins( class ErtConfigWithPlugins(ErtConfig): PREINSTALLED_FORWARD_MODEL_STEPS: ClassVar[ - Dict[str, ForwardModelStepPlugin] + dict[str, ForwardModelStepPlugin] ] = preinstalled_fm_steps - ENV_PR_FM_STEP: ClassVar[Dict[str, Dict[str, Any]]] = env_pr_fm_step + ENV_PR_FM_STEP: ClassVar[dict[str, dict[str, Any]]] = env_pr_fm_step ACTIVATE_SCRIPT = ErtPluginManager().activate_script() assert issubclass(ErtConfigWithPlugins, ErtConfig) @@ -521,7 +514,7 @@ def from_dict(cls, config_dict) -> Self: summary_obs = { obs[1].key for obs in obs_config_content - if isinstance(obs[1], (HistoryValues, SummaryValues)) + if isinstance(obs[1], HistoryValues | SummaryValues) } if summary_obs: summary_keys = ErtConfig._read_summary_keys(config_dict) @@ -578,7 +571,7 @@ def from_dict(cls, config_dict) -> Self: ) @classmethod - def _read_summary_keys(cls, config_dict) -> List[str]: + def _read_summary_keys(cls, config_dict) -> list[str]: return [ item for sublist in config_dict.get(ConfigKeys.SUMMARY, []) @@ -618,7 +611,7 @@ def _log_config_file(cls, config_file: str, config_file_contents: str) -> None: ) @classmethod - def _log_config_dict(cls, content_dict: Dict[str, Any]) -> None: + def _log_config_dict(cls, content_dict: dict[str, Any]) -> None: tmp_dict = content_dict.copy() tmp_dict.pop("FORWARD_MODEL", None) tmp_dict.pop("LOAD_WORKFLOW", None) @@ -707,7 +700,7 @@ def _read_user_config_and_apply_site_config( @staticmethod def check_non_utf_chars(file_path: str) -> None: try: - with open(file_path, "r", encoding="utf-8") as f: + with open(file_path, encoding="utf-8") as f: f.read() except UnicodeDecodeError as e: error_words = str(e).split(" ") @@ -728,7 +721,7 @@ def check_non_utf_chars(file_path: str) -> None: ) from e @classmethod - def _read_templates(cls, config_dict) -> List[Tuple[str, str]]: + def _read_templates(cls, config_dict) -> list[tuple[str, str]]: templates = [] if ConfigKeys.DATA_FILE in config_dict and ConfigKeys.ECLBASE in config_dict: # This replicates the behavior of the DATA_FILE implementation @@ -748,7 +741,7 @@ def _read_templates(cls, config_dict) -> List[Tuple[str, str]]: @classmethod def _validate_dict( cls, config_dict, config_file: str - ) -> List[Union[ErrorInfo, ConfigValidationError]]: + ) -> list[ErrorInfo | ConfigValidationError]: errors = [] if ConfigKeys.SUMMARY in config_dict and ConfigKeys.ECLBASE not in config_dict: @@ -764,10 +757,10 @@ def _validate_dict( @classmethod def _create_list_of_forward_model_steps_to_run( cls, - installed_steps: Dict[str, ForwardModelStep], + installed_steps: dict[str, ForwardModelStep], substitutions: Substitutions, config_dict, - ) -> List[ForwardModelStep]: + ) -> list[ForwardModelStep]: errors = [] fm_steps = [] for fm_step_description in config_dict.get(ConfigKeys.FORWARD_MODEL, []): @@ -855,7 +848,7 @@ def _create_list_of_forward_model_steps_to_run( return fm_steps - def forward_model_step_name_list(self) -> List[str]: + def forward_model_step_name_list(self) -> list[str]: return [j.name for j in self.forward_model_steps] @classmethod @@ -976,7 +969,7 @@ def _workflows_from_dict( @classmethod def _installed_forward_model_steps_from_dict( cls, config_dict - ) -> Dict[str, ForwardModelStep]: + ) -> dict[str, ForwardModelStep]: errors = [] fm_steps = {} for fm_step in config_dict.get(ConfigKeys.INSTALL_JOB, []): @@ -1027,21 +1020,20 @@ def preferred_num_cpu(self) -> int: return int(self.substitutions.get(f"<{ConfigKeys.NUM_CPU}>", 1)) @property - def env_pr_fm_step(self) -> Dict[str, Dict[str, Any]]: + def env_pr_fm_step(self) -> dict[str, dict[str, Any]]: return self.ENV_PR_FM_STEP @staticmethod def _create_observations( - obs_config_content: Optional[ - Dict[str, Union[HistoryValues, SummaryValues, GenObsValues]] - ], + obs_config_content: dict[str, HistoryValues | SummaryValues | GenObsValues] + | None, ensemble_config: EnsembleConfig, - time_map: Optional[List[datetime]], + time_map: list[datetime] | None, history: HistorySource, ) -> EnkfObs: if not obs_config_content: return EnkfObs({}, []) - obs_vectors: Dict[str, ObsVector] = {} + obs_vectors: dict[str, ObsVector] = {} obs_time_list: Sequence[datetime] = [] if ensemble_config.refcase is not None: obs_time_list = ensemble_config.refcase.all_dates @@ -1049,7 +1041,7 @@ def _create_observations( obs_time_list = time_map time_len = len(obs_time_list) - config_errors: List[ErrorInfo] = [] + config_errors: list[ErrorInfo] = [] for obs_name, values in obs_config_content: try: if type(values) == HistoryValues: @@ -1143,8 +1135,8 @@ def _substitutions_from_dict(config_dict) -> Substitutions: def _uppercase_subkeys_and_stringify_subvalues( - nested_dict: Dict[str, Dict[str, Any]], -) -> Dict[str, Dict[str, str]]: + nested_dict: dict[str, dict[str, Any]], +) -> dict[str, dict[str, str]]: fixed_dict: dict[str, dict[str, str]] = {} for key, value in nested_dict.items(): fixed_dict[key] = { @@ -1155,7 +1147,7 @@ def _uppercase_subkeys_and_stringify_subvalues( @no_type_check def _forward_model_step_from_config_file( - config_file: str, name: Optional[str] = None + config_file: str, name: str | None = None ) -> "ForwardModelStep": if name is None: name = os.path.basename(config_file) @@ -1165,7 +1157,7 @@ def _forward_model_step_from_config_file( try: content_dict = parse_config(file=config_file, schema=schema, pre_defines=[]) - specified_arg_types: List[Tuple[int, str]] = content_dict.get( + specified_arg_types: list[tuple[int, str]] = content_dict.get( ForwardModelStepKeys.ARG_TYPE, [] ) @@ -1199,5 +1191,5 @@ def _forward_model_step_from_config_file( exec_env=exec_env, default_mapping=default_mapping, ) - except IOError as err: + except OSError as err: raise ConfigValidationError.with_context(str(err), config_file) from err diff --git a/src/ert/config/ert_plugin.py b/src/ert/config/ert_plugin.py index 5f9f328798c..580a9b48f77 100644 --- a/src/ert/config/ert_plugin.py +++ b/src/ert/config/ert_plugin.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC -from typing import Any, List +from typing import Any from .ert_script import ErtScript @@ -11,7 +11,7 @@ class CancelPluginException(Exception): class ErtPlugin(ErtScript, ABC): - def getArguments(self, args: List[Any]) -> List[Any]: + def getArguments(self, args: list[Any]) -> list[Any]: return [] def getName(self) -> str: diff --git a/src/ert/config/ert_script.py b/src/ert/config/ert_script.py index e7acebb9cd9..2248424a7d3 100644 --- a/src/ert/config/ert_script.py +++ b/src/ert/config/ert_script.py @@ -7,8 +7,9 @@ import traceback import warnings from abc import abstractmethod +from collections.abc import Callable from types import MappingProxyType, ModuleType -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union +from typing import TYPE_CHECKING, Any, TypeAlias from typing_extensions import deprecated @@ -16,7 +17,7 @@ from ert.config import ErtConfig from ert.storage import Ensemble, Storage - Fixtures = Union[ErtConfig, Ensemble, Storage] + Fixtures: TypeAlias = ErtConfig | Ensemble | Storage logger = logging.getLogger(__name__) @@ -71,12 +72,12 @@ def stderrdata(self) -> str: return self._stderrdata @deprecated("Use fixtures to the run function instead") - def ert(self) -> Optional[ErtConfig]: + def ert(self) -> ErtConfig | None: logger.info(f"Accessing EnKFMain from workflow: {self.__class__.__name__}") return self._ert @property - def ensemble(self) -> Optional[Ensemble]: + def ensemble(self) -> Ensemble | None: warnings.warn( "The ensemble property is deprecated, use the fixture to the run function instead", DeprecationWarning, @@ -86,7 +87,7 @@ def ensemble(self) -> Optional[Ensemble]: return self._ensemble @property - def storage(self) -> Optional[Storage]: + def storage(self) -> Storage | None: warnings.warn( "The storage property is deprecated, use the fixture to the run function instead", DeprecationWarning, @@ -111,9 +112,9 @@ def cleanup(self) -> None: def initializeAndRun( self, - argument_types: list[Type[Any]], + argument_types: list[type[Any]], argument_values: list[str], - fixtures: Optional[Dict[str, Any]] = None, + fixtures: dict[str, Any] | None = None, ) -> Any: fixtures = {} if fixtures is None else fixtures arguments = [] @@ -169,7 +170,7 @@ def initializeAndRun( def insert_fixtures( self, func_args: MappingProxyType[str, inspect.Parameter], - fixtures: Dict[str, Fixtures], + fixtures: dict[str, Fixtures], ) -> list[Any]: arguments = [] errors = [] @@ -198,7 +199,7 @@ def output_stack_trace(self, error: str = "") -> None: @staticmethod def loadScriptFromFile( path: str, - ) -> Callable[[], "ErtScript"]: + ) -> Callable[[], ErtScript]: module_name = f"ErtScriptModule_{ErtScript.__module_count}" ErtScript.__module_count += 1 @@ -219,7 +220,7 @@ def loadScriptFromFile( @staticmethod def __findErtScriptImplementations( module: ModuleType, - ) -> Callable[[], "ErtScript"]: + ) -> Callable[[], ErtScript]: result = [] for _, member in inspect.getmembers( module, diff --git a/src/ert/config/ext_param_config.py b/src/ert/config/ext_param_config.py index 376acf64d72..cd9378f02c5 100644 --- a/src/ert/config/ext_param_config.py +++ b/src/ert/config/ext_param_config.py @@ -1,9 +1,10 @@ from __future__ import annotations import json +from collections.abc import Mapping, MutableMapping from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Mapping, MutableMapping +from typing import TYPE_CHECKING import numpy as np import xarray as xr @@ -116,7 +117,7 @@ def to_dataset(data: DataType) -> xr.Dataset: names: list[str] = [] values: list[float] = [] for outer_key, outer_val in data.items(): - if isinstance(outer_val, (int, float)): + if isinstance(outer_val, int | float): names.append(outer_key) values.append(float(outer_val)) continue diff --git a/src/ert/config/external_ert_script.py b/src/ert/config/external_ert_script.py index 4138e5435a5..732cac7402b 100644 --- a/src/ert/config/external_ert_script.py +++ b/src/ert/config/external_ert_script.py @@ -3,7 +3,7 @@ import codecs import sys from subprocess import PIPE, Popen -from typing import Any, Optional +from typing import Any from .ert_script import ErtScript @@ -13,7 +13,7 @@ def __init__(self, executable: str): super().__init__() self.__executable = executable - self.__job: Optional[Popen[bytes]] = None + self.__job: Popen[bytes] | None = None def run(self, *args: Any) -> None: command = [self.__executable] diff --git a/src/ert/config/field.py b/src/ert/config/field.py index 834cdba899e..0e2a1bbd525 100644 --- a/src/ert/config/field.py +++ b/src/ert/config/field.py @@ -5,7 +5,7 @@ import time from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Self, Union, overload +from typing import TYPE_CHECKING, Any, Self, overload import numpy as np import xarray as xr @@ -33,14 +33,14 @@ class Field(ParameterConfig): ny: int nz: int file_format: FieldFileFormat - output_transformation: Optional[str] - input_transformation: Optional[str] - truncation_min: Optional[float] - truncation_max: Optional[float] + output_transformation: str | None + input_transformation: str | None + truncation_min: float | None + truncation_max: float | None forward_init_file: str output_file: Path grid_file: str - mask_file: Optional[Path] = None + mask_file: Path | None = None @classmethod def from_config_list( @@ -263,29 +263,27 @@ def mask(self) -> Any: @overload def field_transform( - data: xr.DataArray, transform_name: Optional[str] -) -> Union[npt.NDArray[np.float32], xr.DataArray]: + data: xr.DataArray, transform_name: str | None +) -> npt.NDArray[np.float32] | xr.DataArray: pass @overload def field_transform( - data: npt.NDArray[np.float32], transform_name: Optional[str] + data: npt.NDArray[np.float32], transform_name: str | None ) -> npt.NDArray[np.float32]: pass def field_transform( - data: Union[xr.DataArray, npt.NDArray[np.float32]], transform_name: Optional[str] -) -> Union[npt.NDArray[np.float32], xr.DataArray]: + data: xr.DataArray | npt.NDArray[np.float32], transform_name: str | None +) -> npt.NDArray[np.float32] | xr.DataArray: if transform_name is None: return data return TRANSFORM_FUNCTIONS[transform_name](data) # type: ignore -def _field_truncate( - data: npt.ArrayLike, min_: Optional[float], max_: Optional[float] -) -> Any: +def _field_truncate(data: npt.ArrayLike, min_: float | None, max_: float | None) -> Any: if min_ is not None and max_ is not None: vfunc = np.vectorize(lambda x: max(min(x, max_), min_)) return vfunc(data) diff --git a/src/ert/config/forward_model_step.py b/src/ert/config/forward_model_step.py index 37bb561e073..31bcbdba0a6 100644 --- a/src/ert/config/forward_model_step.py +++ b/src/ert/config/forward_model_step.py @@ -3,17 +3,15 @@ import logging from abc import abstractmethod from dataclasses import dataclass, field -from typing import ClassVar, Literal +from typing import ClassVar, Literal, NotRequired from pydantic import field_validator -from typing_extensions import NotRequired, TypedDict, Unpack +from typing_extensions import TypedDict, Unpack from ert.config.parsing.config_errors import ConfigWarning from ert.substitutions import Substitutions -from .parsing import ( - SchemaItemType, -) +from .parsing import SchemaItemType logger = logging.getLogger(__name__) diff --git a/src/ert/config/gen_data_config.py b/src/ert/config/gen_data_config.py index 39c70a6cc71..12d034b85ee 100644 --- a/src/ert/config/gen_data_config.py +++ b/src/ert/config/gen_data_config.py @@ -64,7 +64,7 @@ def from_config_dict(cls, config_dict: ConfigDict) -> Self | None: name, ) try: - _report_steps: list[int] | None = rangestring_to_list( + report_steps_: list[int] | None = rangestring_to_list( report_steps_value ) except ValueError as e: @@ -74,7 +74,7 @@ def from_config_dict(cls, config_dict: ConfigDict) -> Self | None: gen_data, ) from e - _report_steps = sorted(_report_steps) if _report_steps else None + report_steps_ = sorted(report_steps_) if report_steps_ else None if os.path.isabs(res_file): result_file_context = next( x for x in gen_data if x.startswith("RESULT_FILE:") @@ -85,7 +85,7 @@ def from_config_dict(cls, config_dict: ConfigDict) -> Self | None: result_file_context, ) - if _report_steps is None and "%d" in res_file: + if report_steps_ is None and "%d" in res_file: raise ConfigValidationError.from_info( ErrorInfo( message="RESULT_FILES using %d must have REPORT_STEPS:xxxx" @@ -94,19 +94,19 @@ def from_config_dict(cls, config_dict: ConfigDict) -> Self | None: ).set_context_keyword(gen_data) ) - if _report_steps is not None and "%d" not in res_file: + if report_steps_ is not None and "%d" not in res_file: result_file_context = next( x for x in gen_data if x.startswith("RESULT_FILE:") ) raise ConfigValidationError.from_info( ErrorInfo( - message=f"When configuring REPORT_STEPS:{_report_steps} " + message=f"When configuring REPORT_STEPS:{report_steps_} " "RESULT_FILES must be configured using %d" ).set_context_keyword(result_file_context) ) keys.append(name) - report_steps.append(_report_steps) + report_steps.append(report_steps_) input_files.append(res_file) return cls( @@ -141,7 +141,7 @@ def _read_file(filename: Path, report_step: int) -> polars.DataFrame: errors = [] - _run_path = Path(run_path) + run_path_ = Path(run_path) datasets_per_name = [] for name, input_file, report_steps in zip( @@ -151,7 +151,7 @@ def _read_file(filename: Path, report_step: int) -> polars.DataFrame: if report_steps is None: try: filename = substitute_runpath_name(input_file, iens, iter) - datasets_per_report_step.append(_read_file(_run_path / filename, 0)) + datasets_per_report_step.append(_read_file(run_path_ / filename, 0)) except (InvalidResponseFile, FileNotFoundError) as err: errors.append(err) else: @@ -161,7 +161,7 @@ def _read_file(filename: Path, report_step: int) -> polars.DataFrame: ) try: datasets_per_report_step.append( - _read_file(_run_path / filename, report_step) + _read_file(run_path_ / filename, report_step) ) except (InvalidResponseFile, FileNotFoundError) as err: errors.append(err) diff --git a/src/ert/config/gen_kw_config.py b/src/ert/config/gen_kw_config.py index 97f8bc153cb..47a20042874 100644 --- a/src/ert/config/gen_kw_config.py +++ b/src/ert/config/gen_kw_config.py @@ -4,16 +4,11 @@ import os import shutil import warnings +from collections.abc import Callable from dataclasses import dataclass from hashlib import sha256 from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Self, - overload, -) +from typing import TYPE_CHECKING, Any, Self, overload import numpy as np import pandas as pd @@ -165,7 +160,7 @@ def from_config_list(cls, gen_kw: list[str]) -> Self: raise ConfigValidationError.from_collected(errors) transform_function_definitions: list[TransformFunctionDefinition] = [] - with open(parameter_file, "r", encoding="utf-8") as file: + with open(parameter_file, encoding="utf-8") as file: for item in file: item = item.split("--")[0] # remove comments if item.strip(): # only lines with content @@ -308,7 +303,7 @@ def write_to_runpath( template_file_path = ( ensemble.experiment.mount_point / Path(self.template_file).name ) - with open(template_file_path, "r", encoding="utf-8") as f: + with open(template_file_path, encoding="utf-8") as f: template = f.read() for key, value in data.items(): template = template.replace(f"<{key}>", f"{value:.6g}") @@ -351,8 +346,8 @@ def shouldUseLogScale(self, keyword: str) -> bool: return tf.use_log return False - def get_priors(self) -> list["PriorDict"]: - priors: list["PriorDict"] = [] + def get_priors(self) -> list[PriorDict]: + priors: list[PriorDict] = [] for tf in self.transform_functions: priors.append( { @@ -428,8 +423,7 @@ def _sample_value( parameter_values = [] for key in keys: key_hash = sha256( - global_seed.encode("utf-8") - + f"{parameter_group_name}:{key}".encode("utf-8") + global_seed.encode("utf-8") + f"{parameter_group_name}:{key}".encode() ) seed = np.frombuffer(key_hash.digest(), dtype="uint32") rng = np.random.default_rng(seed) @@ -514,16 +508,14 @@ def trans_errf(x: float, arg: list[float]) -> float: Skewness > 0 => Shifts towards the right The width is a relavant scale for the value of skewness. """ - _min, _max, _skew, _width = arg[0], arg[1], arg[2], arg[3] - y = norm(loc=0, scale=_width).cdf(x + _skew) + min_, max_, skew, width = arg[0], arg[1], arg[2], arg[3] + y = norm(loc=0, scale=width).cdf(x + skew) if np.isnan(y): raise ValueError( - ( - "Output is nan, likely from triplet (x, skewness, width) " - "leading to low/high-probability in normal CDF." - ) + "Output is nan, likely from triplet (x, skewness, width) " + "leading to low/high-probability in normal CDF." ) - return _min + y * (_max - _min) + return min_ + y * (max_ - min_) @staticmethod def trans_const(_: float, arg: list[float]) -> float: @@ -539,25 +531,25 @@ def trans_derrf(x: float, arg: list[float]) -> float: Bin the result of `trans_errf` with `min=0` and `max=1` to closest of `nbins` linearly spaced values on [0,1]. Finally map [0,1] to [min, max]. """ - _steps, _min, _max, _skew, _width = ( + steps, min_, max_, skew, width = ( int(arg[0]), arg[1], arg[2], arg[3], arg[4], ) - q_values = np.linspace(start=0, stop=1, num=_steps) - q_checks = np.linspace(start=0, stop=1, num=_steps + 1)[1:] - y = TransformFunction.trans_errf(x, [0, 1, _skew, _width]) + q_values = np.linspace(start=0, stop=1, num=steps) + q_checks = np.linspace(start=0, stop=1, num=steps + 1)[1:] + y = TransformFunction.trans_errf(x, [0, 1, skew, width]) bin_index = np.digitize(y, q_checks, right=True) y_binned = q_values[bin_index] - result = _min + y_binned * (_max - _min) - if result > _max or result < _min: + result = min_ + y_binned * (max_ - min_) + if result > max_ or result < min_: warnings.warn( "trans_derff suffered from catastrophic loss of precision, clamping to min,max", stacklevel=1, ) - return np.clip(result, _min, _max) + return np.clip(result, min_, max_) if np.isnan(result): raise ValueError( "trans_derrf returns nan, check that input arguments are reasonable" @@ -566,52 +558,52 @@ def trans_derrf(x: float, arg: list[float]) -> float: @staticmethod def trans_unif(x: float, arg: list[float]) -> float: - _min, _max = arg[0], arg[1] + min_, max_ = arg[0], arg[1] y = norm.cdf(x) - return y * (_max - _min) + _min + return y * (max_ - min_) + min_ @staticmethod def trans_dunif(x: float, arg: list[float]) -> float: - _steps, _min, _max = int(arg[0]), arg[1], arg[2] + steps, min_, max_ = int(arg[0]), arg[1], arg[2] y = norm.cdf(x) - return (math.floor(y * _steps) / (_steps - 1)) * (_max - _min) + _min + return (math.floor(y * steps) / (steps - 1)) * (max_ - min_) + min_ @staticmethod def trans_normal(x: float, arg: list[float]) -> float: - _mean, _std = arg[0], arg[1] - return x * _std + _mean + mean, std = arg[0], arg[1] + return x * std + mean @staticmethod def trans_truncated_normal(x: float, arg: list[float]) -> float: - _mean, _std, _min, _max = arg[0], arg[1], arg[2], arg[3] - y = x * _std + _mean - return max(min(y, _max), _min) # clamp + mean, std, min_, max_ = arg[0], arg[1], arg[2], arg[3] + y = x * std + mean + return max(min(y, max_), min_) # clamp @staticmethod def trans_lognormal(x: float, arg: list[float]) -> float: # mean is the expectation of log( y ) - _mean, _std = arg[0], arg[1] - return math.exp(x * _std + _mean) + mean, std = arg[0], arg[1] + return math.exp(x * std + mean) @staticmethod def trans_logunif(x: float, arg: list[float]) -> float: - _log_min, _log_max = math.log(arg[0]), math.log(arg[1]) + log_min, log_max = math.log(arg[0]), math.log(arg[1]) tmp = norm.cdf(x) - log_y = _log_min + tmp * (_log_max - _log_min) # Shift according to max / min + log_y = log_min + tmp * (log_max - log_min) # Shift according to max / min return math.exp(log_y) @staticmethod def trans_triangular(x: float, arg: list[float]) -> float: - _min, _mode, _max = arg[0], arg[1], arg[2] - inv_norm_left = (_max - _min) * (_mode - _min) - inv_norm_right = (_max - _min) * (_max - _mode) - ymode = (_mode - _min) / (_max - _min) + min_, mode, max_ = arg[0], arg[1], arg[2] + inv_norm_left = (max_ - min_) * (mode - min_) + inv_norm_right = (max_ - min_) * (max_ - mode) + ymode = (mode - min_) / (max_ - min_) y = norm.cdf(x) if y < ymode: - return _min + math.sqrt(y * inv_norm_left) + return min_ + math.sqrt(y * inv_norm_left) else: - return _max - math.sqrt((1 - y) * inv_norm_right) + return max_ - math.sqrt((1 - y) * inv_norm_right) def calculate(self, x: float, arg: list[float]) -> float: return self.calc_func(x, arg) diff --git a/src/ert/config/general_observation.py b/src/ert/config/general_observation.py index 40744641f7a..ec9858ae4f4 100644 --- a/src/ert/config/general_observation.py +++ b/src/ert/config/general_observation.py @@ -1,17 +1,16 @@ from __future__ import annotations from dataclasses import dataclass -from typing import List import numpy as np @dataclass(eq=False) class GenObservation: - values: List[float] - stds: List[float] - indices: List[int] - std_scaling: List[float] + values: list[float] + stds: list[float] + indices: list[int] + std_scaling: list[float] def __post_init__(self) -> None: for val in self.stds: diff --git a/src/ert/config/model_config.py b/src/ert/config/model_config.py index 49db632a331..25aa53a867d 100644 --- a/src/ert/config/model_config.py +++ b/src/ert/config/model_config.py @@ -31,7 +31,7 @@ def str_to_datetime(date_str: str) -> datetime: return datetime.strptime(date_str, "%d/%m/%Y") dates = [] - with open(file_name, "r", encoding="utf-8") as fin: + with open(file_name, encoding="utf-8") as fin: for line in fin: dates.append(str_to_datetime(line.strip())) return dates @@ -102,7 +102,7 @@ def transform(cls, eclbase_format_string: str) -> str: @no_type_check @classmethod - def from_dict(cls, config_dict: ConfigDict) -> "ModelConfig": + def from_dict(cls, config_dict: ConfigDict) -> ModelConfig: time_map_file = config_dict.get(ConfigKeys.TIME_MAP) time_map_file = ( os.path.abspath(time_map_file) if time_map_file is not None else None @@ -111,7 +111,7 @@ def from_dict(cls, config_dict: ConfigDict) -> "ModelConfig": if time_map_file is not None: try: time_map = _read_time_map(time_map_file) - except (ValueError, IOError) as err: + except (OSError, ValueError) as err: raise ConfigValidationError.with_context( f"Could not read timemap file {time_map_file}: {err}", time_map_file ) from err diff --git a/src/ert/config/observation_vector.py b/src/ert/config/observation_vector.py index 9044ce116eb..371eab9d396 100644 --- a/src/ert/config/observation_vector.py +++ b/src/ert/config/observation_vector.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Iterable from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, Iterable, List, Union +from typing import TYPE_CHECKING import numpy as np @@ -20,16 +21,16 @@ class ObsVector: observation_type: EnkfObservationImplementationType observation_key: str data_key: str - observations: Dict[Union[int, datetime], Union[GenObservation, SummaryObservation]] + observations: dict[int | datetime, GenObservation | SummaryObservation] - def __iter__(self) -> Iterable[Union[SummaryObservation, GenObservation]]: + def __iter__(self) -> Iterable[SummaryObservation | GenObservation]: """Iterate over active report steps; return node""" return iter(self.observations.values()) def __len__(self) -> int: return len(self.observations) - def to_dataset(self, active_list: List[int]) -> polars.DataFrame: + def to_dataset(self, active_list: list[int]) -> polars.DataFrame: if self.observation_type == EnkfObservationImplementationType.GEN_OBS: dataframes = [] for time_step, node in self.observations.items(): diff --git a/src/ert/config/observations.py b/src/ert/config/observations.py index d5d14260585..d3610a785cd 100644 --- a/src/ert/config/observations.py +++ b/src/ert/config/observations.py @@ -1,8 +1,9 @@ import os +from collections.abc import Iterator from dataclasses import dataclass, field from datetime import datetime, timedelta from pathlib import Path -from typing import TYPE_CHECKING, Iterator +from typing import TYPE_CHECKING import numpy as np import polars diff --git a/src/ert/config/parsing/config_dict.py b/src/ert/config/parsing/config_dict.py index f9b278180dd..d7547544226 100644 --- a/src/ert/config/parsing/config_dict.py +++ b/src/ert/config/parsing/config_dict.py @@ -1,6 +1,4 @@ -from typing import Dict - from .context_values import ContextString from .types import MaybeWithContext -ConfigDict = Dict[ContextString, MaybeWithContext] +ConfigDict = dict[ContextString, MaybeWithContext] diff --git a/src/ert/config/parsing/config_errors.py b/src/ert/config/parsing/config_errors.py index d6cafc59499..1e5de728748 100644 --- a/src/ert/config/parsing/config_errors.py +++ b/src/ert/config/parsing/config_errors.py @@ -1,7 +1,8 @@ from __future__ import annotations import warnings -from typing import Callable, Self, Sequence, Type +from collections.abc import Callable, Sequence +from typing import Self from .error_info import ErrorInfo, WarningInfo from .types import MaybeWithContext @@ -28,7 +29,7 @@ def _formatted_warn(cls, config_warning: ConfigWarning) -> None: def ert_formatted_warning( message: Warning | str, - category: Type[Warning], + category: type[Warning], filename: str, lineno: int, line: str | None = None, diff --git a/src/ert/config/parsing/config_schema_deprecations.py b/src/ert/config/parsing/config_schema_deprecations.py index 33a0ece4859..f73bf9997ae 100644 --- a/src/ert/config/parsing/config_schema_deprecations.py +++ b/src/ert/config/parsing/config_schema_deprecations.py @@ -1,5 +1,5 @@ from functools import partial -from typing import List, cast +from typing import cast from .deprecation_info import DeprecationInfo @@ -228,6 +228,6 @@ message="Memory requirements in LSF should now be set using REALIZATION_MEMORY and not" " through the LSF_RESOURCE option.", check=lambda line: "LSF_RESOURCE" in line - and "mem=" in cast(List[str], line)[-1], + and "mem=" in cast(list[str], line)[-1], ), ] diff --git a/src/ert/config/parsing/config_schema_item.py b/src/ert/config/parsing/config_schema_item.py index afc66f7c752..e6c5e38ee31 100644 --- a/src/ert/config/parsing/config_schema_item.py +++ b/src/ert/config/parsing/config_schema_item.py @@ -1,7 +1,8 @@ import os import shutil +from collections.abc import Mapping from enum import EnumType -from typing import List, Mapping, Optional, TypeVar, Union +from typing import TypeVar from pydantic import ConfigDict, Field, NonNegativeInt, PositiveInt from pydantic.dataclasses import dataclass @@ -31,15 +32,15 @@ class SchemaItem: # The minimum number of arguments argc_min: NonNegativeInt = 1 # The maximum number of arguments: None means no upper limit - argc_max: Optional[NonNegativeInt] = 1 + argc_max: NonNegativeInt | None = 1 # A list of types for the items. Set along with argc_minmax() - type_map: List[Union[SchemaItemType, EnumType, None]] = Field(default_factory=list) + type_map: list[SchemaItemType | EnumType | None] = Field(default_factory=list) # A list of item's which must also be set (if this item is set). (can be NULL) - required_children: List[str] = Field(default_factory=list) + required_children: list[str] = Field(default_factory=list) # Information about the deprecation if deprecated - deprecation_info: List[DeprecationInfo] = Field(default_factory=list) + deprecation_info: list[DeprecationInfo] = Field(default_factory=list) # if positive, arguments after this count will be concatenated with a " " between - join_after: Optional[PositiveInt] = None + join_after: PositiveInt | None = None # if true, will accumulate many values set for key, otherwise each entry will # overwrite any previous value set multi_occurrence: bool = False @@ -47,7 +48,7 @@ class SchemaItem: # Index of tokens to do substitution from until end substitute_from: NonNegativeInt = 1 required_set: bool = False - required_children_value: Mapping[str, List[str]] = Field(default_factory=dict) + required_children_value: Mapping[str, list[str]] = Field(default_factory=dict) @classmethod def deprecated_dummy_keyword(cls, info: DeprecationInfo) -> "SchemaItem": @@ -62,7 +63,7 @@ def deprecated_dummy_keyword(cls, info: DeprecationInfo) -> "SchemaItem": def token_to_value_with_context( self, token: FileContextToken, index: int, keyword: FileContextToken, cwd: str - ) -> Optional[MaybeWithContext]: + ) -> MaybeWithContext | None: """ Converts a FileContextToken to a value with context that behaves like a value, but also contains its location in the file, @@ -140,7 +141,7 @@ def token_to_value_with_context( ) case SchemaItemType.PATH | SchemaItemType.EXISTING_PATH: - path: Optional[str] = str(token) + path: str | None = str(token) if not os.path.isabs(token): path = os.path.normpath( os.path.join(os.path.dirname(token.filename), token) @@ -156,7 +157,7 @@ def token_to_value_with_context( assert isinstance(path, str) return ContextString(path, token, keyword) case SchemaItemType.EXECUTABLE: - absolute_path: Optional[str] + absolute_path: str | None is_command = False if not os.path.isabs(token): # Try relative @@ -207,15 +208,13 @@ def token_to_value_with_context( def apply_constraints( self, - args: List[T], + args: list[T], keyword: FileContextToken, cwd: str, - ) -> Union[ - T, MaybeWithContext, ContextList[Union[T, MaybeWithContext, None]], None - ]: - errors: List[Union[ErrorInfo, ConfigValidationError]] = [] + ) -> T | MaybeWithContext | ContextList[T | MaybeWithContext | None] | None: + errors: list[ErrorInfo | ConfigValidationError] = [] - args_with_context: ContextList[Union[T, MaybeWithContext, None]] = ContextList( + args_with_context: ContextList[T | MaybeWithContext | None] = ContextList( token=keyword ) for i, x in enumerate(args): @@ -253,7 +252,7 @@ def apply_constraints( return args_with_context - def join_args(self, line: List[FileContextToken]) -> List[FileContextToken]: + def join_args(self, line: list[FileContextToken]) -> list[FileContextToken]: n = self.join_after if n is not None and n < len(line): joined = FileContextToken.join_tokens(line[n:], " ") diff --git a/src/ert/config/parsing/context_values.py b/src/ert/config/parsing/context_values.py index 9b7c39416d0..1158b0af1e5 100644 --- a/src/ert/config/parsing/context_values.py +++ b/src/ert/config/parsing/context_values.py @@ -1,5 +1,5 @@ from json import JSONEncoder -from typing import Any, List, TypeVar, Union, no_type_check +from typing import Any, TypeVar, no_type_check from .file_context_token import FileContextToken @@ -87,7 +87,7 @@ def __deepcopy__(self, memo): T = TypeVar("T") -class ContextList(List[T]): +class ContextList(list[T]): keyword_token: FileContextToken def __init__(self, token: FileContextToken) -> None: @@ -96,11 +96,11 @@ def __init__(self, token: FileContextToken) -> None: @classmethod def with_values( - cls, token: FileContextToken, values: List["ContextValue"] + cls, token: FileContextToken, values: list["ContextValue"] ) -> "ContextList[ContextValue]": - the_list: "ContextList[ContextValue]" = ContextList(token) + the_list: ContextList[ContextValue] = ContextList(token) the_list += values return the_list -ContextValue = Union[ContextString, ContextFloat, ContextInt, ContextBool] +ContextValue = ContextString | ContextFloat | ContextInt | ContextBool diff --git a/src/ert/config/parsing/deprecation_info.py b/src/ert/config/parsing/deprecation_info.py index 86fd1d1f5f5..4bf4db51ab1 100644 --- a/src/ert/config/parsing/deprecation_info.py +++ b/src/ert/config/parsing/deprecation_info.py @@ -1,5 +1,5 @@ +from collections.abc import Callable from dataclasses import dataclass -from typing import Callable, List, Optional, Union from .context_values import ContextValue @@ -7,10 +7,10 @@ @dataclass class DeprecationInfo: keyword: str - message: Union[str, Callable[[List[str]], str]] - check: Optional[Callable[[List[ContextValue]], bool]] = None + message: str | Callable[[list[str]], str] + check: Callable[[list[ContextValue]], bool] | None = None - def resolve_message(self, line: List[str]) -> str: + def resolve_message(self, line: list[str]) -> str: if callable(self.message): return self.message(line) diff --git a/src/ert/config/parsing/error_info.py b/src/ert/config/parsing/error_info.py index c817e1a2081..373b850a69e 100644 --- a/src/ert/config/parsing/error_info.py +++ b/src/ert/config/parsing/error_info.py @@ -1,5 +1,6 @@ +from collections.abc import Sequence from dataclasses import dataclass -from typing import Self, Sequence +from typing import Self from .file_context_token import FileContextToken from .types import MaybeWithContext diff --git a/src/ert/config/parsing/file_context_token.py b/src/ert/config/parsing/file_context_token.py index c9aca00dc09..5079a8c694c 100644 --- a/src/ert/config/parsing/file_context_token.py +++ b/src/ert/config/parsing/file_context_token.py @@ -1,4 +1,4 @@ -from typing import List, cast +from typing import cast from lark import Token @@ -11,7 +11,7 @@ class FileContextToken(Token): filename: str def __new__(cls, token: Token, filename: str) -> "FileContextToken": - inst = super(FileContextToken, cls).__new__( + inst = super().__new__( cls, token.type, token.value, @@ -46,7 +46,7 @@ def __hash__(self) -> int: # type: ignore @classmethod def join_tokens( - cls, tokens: List["FileContextToken"], separator: str = " " + cls, tokens: list["FileContextToken"], separator: str = " " ) -> "FileContextToken": first = tokens[0] min_start_pos = min(x.start_pos for x in tokens if x.start_pos is not None) diff --git a/src/ert/config/parsing/forward_model_schema.py b/src/ert/config/parsing/forward_model_schema.py index 4cab3ee939c..9f2e659719d 100644 --- a/src/ert/config/parsing/forward_model_schema.py +++ b/src/ert/config/parsing/forward_model_schema.py @@ -1,5 +1,3 @@ -from typing import List - from .config_dict import ConfigDict from .config_schema_item import SchemaItem from .deprecation_info import DeprecationInfo @@ -139,7 +137,7 @@ def default_keyword() -> SchemaItem: ) -forward_model_schema_items: List[SchemaItem] = [ +forward_model_schema_items: list[SchemaItem] = [ executable_keyword(), stdin_keyword(), stdout_keyword(), @@ -158,7 +156,7 @@ def default_keyword() -> SchemaItem: exec_env_keyword(), ] -forward_model_deprecations: List[DeprecationInfo] = [ +forward_model_deprecations: list[DeprecationInfo] = [ DeprecationInfo( keyword="PORTABLE_EXE", message='"PORTABLE_EXE" key is deprecated, please replace with "EXECUTABLE"', diff --git a/src/ert/config/parsing/observations_parser.py b/src/ert/config/parsing/observations_parser.py index 07089d0869e..b06c505b63e 100644 --- a/src/ert/config/parsing/observations_parser.py +++ b/src/ert/config/parsing/observations_parser.py @@ -1,11 +1,11 @@ import os from collections import Counter +from collections.abc import Sequence from dataclasses import dataclass from enum import Enum, auto from typing import ( Any, Literal, - Sequence, cast, no_type_check, ) diff --git a/src/ert/config/parsing/schema_dict.py b/src/ert/config/parsing/schema_dict.py index 69d1bceac4a..f90d8245d0f 100644 --- a/src/ert/config/parsing/schema_dict.py +++ b/src/ert/config/parsing/schema_dict.py @@ -1,6 +1,6 @@ import abc from collections import UserDict -from typing import List, Set, no_type_check +from typing import no_type_check from .config_dict import ConfigDict from .config_errors import ConfigValidationError, ConfigWarning @@ -23,7 +23,7 @@ def search_for_unset_required_keywords( # both with the same value # which causes .values() to return the NUM_REALIZATIONS keyword twice # which again leads to duplicate collection of errors related to this - visited: Set[str] = set() + visited: set[str] = set() for constraints in self.values(): if constraints.kw in visited: @@ -42,7 +42,7 @@ def search_for_unset_required_keywords( if errors: raise ConfigValidationError.from_collected(errors) - def add_deprecations(self, deprecated_keywords_list: List[DeprecationInfo]) -> None: + def add_deprecations(self, deprecated_keywords_list: list[DeprecationInfo]) -> None: for info in deprecated_keywords_list: # Add it to the schema only so that it is # catched by the parser @@ -59,7 +59,7 @@ def search_for_deprecated_keyword_usages( ) -> None: detected_deprecations = [] - def push_deprecation(infos: List[DeprecationInfo], line: List[ContextString]): + def push_deprecation(infos: list[DeprecationInfo], line: list[ContextString]): for info in infos: if info.check is None or (callable(info.check) and info.check(line)): detected_deprecations.append((info, line)) diff --git a/src/ert/config/parsing/types.py b/src/ert/config/parsing/types.py index 3a94efef72e..682fb695f84 100644 --- a/src/ert/config/parsing/types.py +++ b/src/ert/config/parsing/types.py @@ -1,16 +1,16 @@ from enum import Enum -from typing import Any, List, Tuple, Union +from typing import Any from .context_values import ContextValue from .file_context_token import FileContextToken # The type of the leaf nodes in the Tree after transformation is done -Instruction = List[ - List[Union[FileContextToken, List[Tuple[FileContextToken, FileContextToken]]]] +Instruction = list[ + list[FileContextToken | list[tuple[FileContextToken, FileContextToken]]] ] -Defines = List[List[str]] +Defines = list[list[str]] -Primitives = Union[float, bool, str, int, Enum] +Primitives = float | bool | str | int | Enum -MaybeWithContext = Union[ContextValue, Primitives, FileContextToken, List[Any]] +MaybeWithContext = ContextValue | Primitives | FileContextToken | list[Any] diff --git a/src/ert/config/queue_config.py b/src/ert/config/queue_config.py index af55314cb5f..4e5ceba05d0 100644 --- a/src/ert/config/queue_config.py +++ b/src/ert/config/queue_config.py @@ -5,8 +5,9 @@ import re import shutil from abc import abstractmethod +from collections.abc import Mapping from dataclasses import asdict, field, fields -from typing import Annotated, Any, Literal, Mapping, Optional, no_type_check +from typing import Annotated, Any, Literal, no_type_check import pydantic from pydantic.dataclasses import dataclass @@ -38,7 +39,7 @@ class QueueOptions: name: str max_running: pydantic.NonNegativeInt = 0 submit_sleep: pydantic.NonNegativeFloat = 0.0 - project_code: Optional[str] = None + project_code: str | None = None activate_script: str = field(default_factory=activate_script) @staticmethod @@ -46,7 +47,7 @@ def create_queue_options( queue_system: QueueSystem, options: dict[str, Any], is_selected_queue_system: bool, - ) -> Optional[QueueOptions]: + ) -> QueueOptions | None: lower_case_options = {key.lower(): value for key, value in options.items()} try: if queue_system == QueueSystem.LSF: @@ -99,13 +100,13 @@ def driver_options(self) -> dict[str, Any]: @pydantic.dataclasses.dataclass class LsfQueueOptions(QueueOptions): name: Literal[QueueSystem.LSF] = QueueSystem.LSF - bhist_cmd: Optional[NonEmptyString] = None - bjobs_cmd: Optional[NonEmptyString] = None - bkill_cmd: Optional[NonEmptyString] = None - bsub_cmd: Optional[NonEmptyString] = None - exclude_host: Optional[str] = None - lsf_queue: Optional[NonEmptyString] = None - lsf_resource: Optional[str] = None + bhist_cmd: NonEmptyString | None = None + bjobs_cmd: NonEmptyString | None = None + bkill_cmd: NonEmptyString | None = None + bsub_cmd: NonEmptyString | None = None + exclude_host: str | None = None + lsf_queue: NonEmptyString | None = None + lsf_resource: str | None = None @property def driver_options(self) -> dict[str, Any]: @@ -122,19 +123,19 @@ def driver_options(self) -> dict[str, Any]: @pydantic.dataclasses.dataclass class TorqueQueueOptions(QueueOptions): name: Literal[QueueSystem.TORQUE] = QueueSystem.TORQUE - qsub_cmd: Optional[NonEmptyString] = None - qstat_cmd: Optional[NonEmptyString] = None - qdel_cmd: Optional[NonEmptyString] = None - queue: Optional[NonEmptyString] = None - memory_per_job: Optional[NonEmptyString] = None + qsub_cmd: NonEmptyString | None = None + qstat_cmd: NonEmptyString | None = None + qdel_cmd: NonEmptyString | None = None + queue: NonEmptyString | None = None + memory_per_job: NonEmptyString | None = None num_cpus_per_node: pydantic.PositiveInt = 1 num_nodes: pydantic.PositiveInt = 1 - cluster_label: Optional[NonEmptyString] = None - job_prefix: Optional[NonEmptyString] = None + cluster_label: NonEmptyString | None = None + job_prefix: NonEmptyString | None = None keep_qsub_output: bool = False - qstat_options: Optional[str] = pydantic.Field(default=None, deprecated=True) - queue_query_timeout: Optional[str] = pydantic.Field(default=None, deprecated=True) + qstat_options: str | None = pydantic.Field(default=None, deprecated=True) + queue_query_timeout: str | None = pydantic.Field(default=None, deprecated=True) @property def driver_options(self) -> dict[str, Any]: @@ -149,7 +150,7 @@ def driver_options(self) -> dict[str, Any]: @pydantic.field_validator("memory_per_job") @classmethod - def check_memory_per_job(cls, value: Optional[str]) -> Optional[str]: + def check_memory_per_job(cls, value: str | None) -> str | None: if not queue_memory_usage_formats[QueueSystem.TORQUE].validate(value): raise ValueError("wrong memory format") return value @@ -165,11 +166,11 @@ class SlurmQueueOptions(QueueOptions): squeue: NonEmptyString = "squeue" exclude_host: str = "" include_host: str = "" - memory: Optional[NonEmptyString] = None - memory_per_cpu: Optional[NonEmptyString] = None - partition: Optional[NonEmptyString] = None # aka queue_name + memory: NonEmptyString | None = None + memory_per_cpu: NonEmptyString | None = None + partition: NonEmptyString | None = None # aka queue_name squeue_timeout: pydantic.PositiveFloat = 2 - max_runtime: Optional[pydantic.NonNegativeFloat] = None + max_runtime: pydantic.NonNegativeFloat | None = None @property def driver_options(self) -> dict[str, Any]: @@ -189,7 +190,7 @@ def driver_options(self) -> dict[str, Any]: @pydantic.field_validator("memory", "memory_per_cpu") @classmethod - def check_memory_per_job(cls, value: Optional[str]) -> Optional[str]: + def check_memory_per_job(cls, value: str | None) -> str | None: if not queue_memory_usage_formats[QueueSystem.SLURM].validate(value): raise ValueError("wrong memory format") return value @@ -199,7 +200,7 @@ def check_memory_per_job(cls, value: Optional[str]) -> Optional[str]: class QueueMemoryStringFormat: suffixes: list[str] - def validate(self, mem_str_format: Optional[str]) -> bool: + def validate(self, mem_str_format: str | None) -> bool: if mem_str_format is None: return True return ( @@ -298,32 +299,30 @@ def from_dict(cls, config_dict: ConfigDict) -> QueueConfig: max_submit: int = config_dict.get(ConfigKeys.MAX_SUBMIT, 1) stop_long_running = config_dict.get(ConfigKeys.STOP_LONG_RUNNING, False) - _raw_queue_options = config_dict.get("QUEUE_OPTION", []) - _grouped_queue_options = _group_queue_options_by_queue_system( - _raw_queue_options - ) - _log_duplicated_queue_options(_raw_queue_options) - _raise_for_defaulted_invalid_options(_raw_queue_options) + raw_queue_options = config_dict.get("QUEUE_OPTION", []) + grouped_queue_options = _group_queue_options_by_queue_system(raw_queue_options) + _log_duplicated_queue_options(raw_queue_options) + _raise_for_defaulted_invalid_options(raw_queue_options) - _all_validated_queue_options = { + all_validated_queue_options = { selected_queue_system: QueueOptions.create_queue_options( selected_queue_system, - _grouped_queue_options[selected_queue_system], + grouped_queue_options[selected_queue_system], True, ) } - _all_validated_queue_options.update( + all_validated_queue_options.update( { _queue_system: QueueOptions.create_queue_options( - _queue_system, _grouped_queue_options[_queue_system], False + _queue_system, grouped_queue_options[_queue_system], False ) for _queue_system in QueueSystem if _queue_system != selected_queue_system } ) - queue_options = _all_validated_queue_options[selected_queue_system] - queue_options_test_run = _all_validated_queue_options[QueueSystem.LOCAL] + queue_options = all_validated_queue_options[selected_queue_system] + queue_options_test_run = all_validated_queue_options[QueueSystem.LOCAL] queue_options.add_global_queue_options(config_dict) if queue_options.project_code is None: @@ -337,10 +336,10 @@ def from_dict(cls, config_dict: ConfigDict) -> QueueConfig: if selected_queue_system == QueueSystem.TORQUE: _check_num_cpu_requirement( - config_dict.get("NUM_CPU", 1), queue_options, _raw_queue_options + config_dict.get("NUM_CPU", 1), queue_options, raw_queue_options ) - for _queue_vals in _all_validated_queue_options.values(): + for _queue_vals in all_validated_queue_options.values(): if ( isinstance(_queue_vals, TorqueQueueOptions) and _queue_vals.memory_per_job diff --git a/src/ert/config/refcase.py b/src/ert/config/refcase.py index 28e80ad2ff5..4318c3e1c33 100644 --- a/src/ert/config/refcase.py +++ b/src/ert/config/refcase.py @@ -1,6 +1,7 @@ +from collections.abc import Sequence from dataclasses import dataclass from datetime import datetime -from typing import Self, Sequence +from typing import Self import numpy as np diff --git a/src/ert/config/response_config.py b/src/ert/config/response_config.py index 814be4a522b..48ccce947b9 100644 --- a/src/ert/config/response_config.py +++ b/src/ert/config/response_config.py @@ -1,6 +1,6 @@ import dataclasses from abc import ABC, abstractmethod -from typing import Any, Optional, Self +from typing import Any, Self import polars @@ -58,7 +58,7 @@ def primary_key(self) -> list[str]: @classmethod @abstractmethod - def from_config_dict(cls, config_dict: ConfigDict) -> Optional[Self]: + def from_config_dict(cls, config_dict: ConfigDict) -> Self | None: """Creates a config, given an ert config dict. A response config may depend on several config kws, such as REFCASE for summary.""" diff --git a/src/ert/config/responses_index.py b/src/ert/config/responses_index.py index 20913a39b4e..4d881696b01 100644 --- a/src/ert/config/responses_index.py +++ b/src/ert/config/responses_index.py @@ -1,13 +1,13 @@ -from typing import Dict, Iterable, Tuple, Type +from collections.abc import Iterable from .response_config import ResponseConfig class _ResponsesIndex: def __init__(self) -> None: - self._items: Dict[str, Type[ResponseConfig]] = {} + self._items: dict[str, type[ResponseConfig]] = {} - def add_response_type(self, response_cls: Type[ResponseConfig]) -> None: + def add_response_type(self, response_cls: type[ResponseConfig]) -> None: if not issubclass(response_cls, ResponseConfig): raise ValueError("Response type must be subclass of ResponseConfig") @@ -21,16 +21,16 @@ def add_response_type(self, response_cls: Type[ResponseConfig]) -> None: self._items[clsname] = response_cls - def values(self) -> Iterable[Type[ResponseConfig]]: + def values(self) -> Iterable[type[ResponseConfig]]: return self._items.values() - def items(self) -> Iterable[Tuple[str, Type[ResponseConfig]]]: + def items(self) -> Iterable[tuple[str, type[ResponseConfig]]]: return self._items.items() def keys(self) -> Iterable[str]: return self._items.keys() - def __getitem__(self, item: str) -> Type[ResponseConfig]: + def __getitem__(self, item: str) -> type[ResponseConfig]: return self._items[item] def __contains__(self, item: str) -> bool: diff --git a/src/ert/config/summary_config.py b/src/ert/config/summary_config.py index bff347a6b87..0e9f9e49282 100644 --- a/src/ert/config/summary_config.py +++ b/src/ert/config/summary_config.py @@ -3,7 +3,7 @@ import logging from dataclasses import dataclass from datetime import datetime -from typing import TYPE_CHECKING, Any, Optional, Set, Union, no_type_check +from typing import TYPE_CHECKING, Any, no_type_check from ert.substitutions import substitute_runpath_name @@ -15,7 +15,7 @@ from .responses_index import responses_index if TYPE_CHECKING: - from typing import List + pass logger = logging.getLogger(__name__) import polars @@ -24,7 +24,7 @@ @dataclass class SummaryConfig(ResponseConfig): name: str = "summary" - refcase: Union[Set[datetime], List[str], None] = None + refcase: set[datetime] | list[str] | None = None has_finalized_keys = False def __post_init__(self) -> None: @@ -35,7 +35,7 @@ def __post_init__(self) -> None: raise ValueError("SummaryConfig must be given at least one key") @property - def expected_input_files(self) -> List[str]: + def expected_input_files(self) -> list[str]: base = self.input_files[0] return [f"{base}.UNSMRY", f"{base}.SMSPEC"] @@ -68,15 +68,15 @@ def response_type(self) -> str: return "summary" @property - def primary_key(self) -> List[str]: + def primary_key(self) -> list[str]: return ["time"] @no_type_check @classmethod - def from_config_dict(cls, config_dict: ConfigDict) -> Optional[SummaryConfig]: + def from_config_dict(cls, config_dict: ConfigDict) -> SummaryConfig | None: refcase = Refcase.from_config_dict(config_dict) if summary_keys := config_dict.get(ConfigKeys.SUMMARY, []): - eclbase: Optional[str] = config_dict.get("ECLBASE") + eclbase: str | None = config_dict.get("ECLBASE") if eclbase is None: raise ConfigValidationError( "In order to use summary responses, ECLBASE has to be set." diff --git a/src/ert/config/workflow.py b/src/ert/config/workflow.py index 7f5d3f3d93a..9d52c447acc 100644 --- a/src/ert/config/workflow.py +++ b/src/ert/config/workflow.py @@ -1,8 +1,9 @@ from __future__ import annotations import os +from collections.abc import Iterator from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple +from typing import TYPE_CHECKING, Any from .parsing import ConfigValidationError, ErrorInfo, init_workflow_schema, parse @@ -15,24 +16,24 @@ @dataclass class Workflow: src_file: str - cmd_list: List[Tuple[WorkflowJob, Any]] + cmd_list: list[tuple[WorkflowJob, Any]] def __len__(self) -> int: return len(self.cmd_list) - def __getitem__(self, index: int) -> Tuple[WorkflowJob, Any]: + def __getitem__(self, index: int) -> tuple[WorkflowJob, Any]: return self.cmd_list[index] - def __iter__(self) -> Iterator[Tuple[WorkflowJob, Any]]: + def __iter__(self) -> Iterator[tuple[WorkflowJob, Any]]: return iter(self.cmd_list) @classmethod def _parse_command_list( cls, src_file: str, - context: List[Tuple[str, str]], - job_dict: Dict[str, WorkflowJob], - ) -> List[Tuple[WorkflowJob, Any]]: + context: list[tuple[str, str]], + job_dict: dict[str, WorkflowJob], + ) -> list[tuple[WorkflowJob, Any]]: schema = init_workflow_schema() config_dict = parse(src_file, schema, pre_defines=context) @@ -86,9 +87,9 @@ def _parse_command_list( def from_file( cls, src_file: str, - context: Optional[Substitutions], - job_dict: Dict[str, WorkflowJob], - ) -> "Workflow": + context: Substitutions | None, + job_dict: dict[str, WorkflowJob], + ) -> Workflow: if not os.path.exists(src_file): raise ConfigValidationError.with_context( f"Workflow file {src_file} does not exist", src_file diff --git a/src/ert/config/workflow_job.py b/src/ert/config/workflow_job.py index ab450c31515..aca7568a9c5 100644 --- a/src/ert/config/workflow_job.py +++ b/src/ert/config/workflow_job.py @@ -3,7 +3,7 @@ import logging import os from dataclasses import dataclass -from typing import Type, TypeAlias +from typing import TypeAlias from .ert_plugin import ErtPlugin from .ert_script import ErtScript @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -ContentTypes: TypeAlias = Type[int] | Type[bool] | Type[float] | Type[str] +ContentTypes: TypeAlias = type[int] | type[bool] | type[float] | type[str] def workflow_job_parser(file: str) -> ConfigDict: @@ -72,7 +72,7 @@ def _make_arg_types_list(content_dict: ConfigDict) -> list[SchemaItemType]: ) @classmethod - def from_file(cls, config_file: str, name: str | None = None) -> "WorkflowJob": + def from_file(cls, config_file: str, name: str | None = None) -> WorkflowJob: if not (os.path.isfile(config_file) and os.access(config_file, os.R_OK)): raise ConfigValidationError(f"Could not open config_file:{config_file!r}") if not name: @@ -100,7 +100,7 @@ def is_plugin(self) -> bool: return issubclass(self.ert_script, ErtPlugin) return False - def argument_types(self) -> list["ContentTypes"]: + def argument_types(self) -> list[ContentTypes]: def content_to_type(c: SchemaItemType | None) -> ContentTypes: if c == SchemaItemType.BOOL: return bool diff --git a/src/ert/dark_storage/client/_session.py b/src/ert/dark_storage/client/_session.py index b10178f559e..b62892164ab 100644 --- a/src/ert/dark_storage/client/_session.py +++ b/src/ert/dark_storage/client/_session.py @@ -1,14 +1,13 @@ import json import os from pathlib import Path -from typing import Optional from pydantic import BaseModel, ValidationError class ConnInfo(BaseModel): base_url: str - auth_token: Optional[str] = None + auth_token: str | None = None ENV_VAR = "ERT_STORAGE_CONNECTION_STRING" @@ -17,7 +16,7 @@ class ConnInfo(BaseModel): # that a single client process will only ever want to connect to a single ERT # Storage server during its lifetime, so we don't provide an API for managing # this cache. -_CACHED_CONN_INFO: Optional[ConnInfo] = None +_CACHED_CONN_INFO: ConnInfo | None = None def find_conn_info() -> ConnInfo: diff --git a/src/ert/dark_storage/client/async_client.py b/src/ert/dark_storage/client/async_client.py index 86cda55aa6a..8ca6060a340 100644 --- a/src/ert/dark_storage/client/async_client.py +++ b/src/ert/dark_storage/client/async_client.py @@ -1,5 +1,3 @@ -from typing import Optional - import httpx from ._session import ConnInfo, find_conn_info @@ -11,7 +9,7 @@ class AsyncClient(httpx.AsyncClient): interact with ERT Storage's API """ - def __init__(self, conn_info: Optional[ConnInfo] = None) -> None: + def __init__(self, conn_info: ConnInfo | None = None) -> None: if conn_info is None: conn_info = find_conn_info() diff --git a/src/ert/dark_storage/client/client.py b/src/ert/dark_storage/client/client.py index ad295878356..4b2e64b0ba7 100644 --- a/src/ert/dark_storage/client/client.py +++ b/src/ert/dark_storage/client/client.py @@ -1,5 +1,3 @@ -from typing import Optional - import httpx from ._session import ConnInfo, find_conn_info @@ -11,7 +9,7 @@ class Client(httpx.Client): interact with ERT Storage's API """ - def __init__(self, conn_info: Optional[ConnInfo] = None) -> None: + def __init__(self, conn_info: ConnInfo | None = None) -> None: if conn_info is None: conn_info = find_conn_info() diff --git a/src/ert/dark_storage/common.py b/src/ert/dark_storage/common.py index f2a6faa0ddb..629f6d46c81 100644 --- a/src/ert/dark_storage/common.py +++ b/src/ert/dark_storage/common.py @@ -1,6 +1,7 @@ import contextlib import logging -from typing import Any, Callable, Iterator +from collections.abc import Callable, Iterator +from typing import Any from uuid import UUID import numpy as np diff --git a/src/ert/dark_storage/compute/misfits.py b/src/ert/dark_storage/compute/misfits.py index c4cce43220f..26c0c86de12 100644 --- a/src/ert/dark_storage/compute/misfits.py +++ b/src/ert/dark_storage/compute/misfits.py @@ -1,4 +1,4 @@ -from typing import Mapping +from collections.abc import Mapping import numpy as np import numpy.typing as npt diff --git a/src/ert/dark_storage/endpoints/compute/misfits.py b/src/ert/dark_storage/endpoints/compute/misfits.py index 799e4c535bd..eb97e0d84dd 100644 --- a/src/ert/dark_storage/endpoints/compute/misfits.py +++ b/src/ert/dark_storage/endpoints/compute/misfits.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Optional, Union +from typing import Any from uuid import UUID import pandas as pd @@ -34,7 +34,7 @@ async def get_response_misfits( storage: Storage = DEFAULT_STORAGEREADER, ensemble_id: UUID, response_name: str, - realization_index: Optional[int] = None, + realization_index: int | None = None, summary_misfits: bool = False, ) -> Response: ensemble = storage.get_ensemble(ensemble_id) @@ -56,7 +56,7 @@ async def get_response_misfits( raise ValueError(f"Cant fetch observations for key {response_name}") o = obs[0] - def parse_index(x: Any) -> Union[int, datetime]: + def parse_index(x: Any) -> int | datetime: try: return int(x) except ValueError: diff --git a/src/ert/dark_storage/endpoints/experiments.py b/src/ert/dark_storage/endpoints/experiments.py index dcb3e01b4db..d0f61fe03d6 100644 --- a/src/ert/dark_storage/endpoints/experiments.py +++ b/src/ert/dark_storage/endpoints/experiments.py @@ -1,4 +1,3 @@ -from typing import List from uuid import UUID from fastapi import APIRouter, Body, Depends @@ -14,11 +13,11 @@ DEFAULT_BODY = Body(...) -@router.get("/experiments", response_model=List[js.ExperimentOut]) +@router.get("/experiments", response_model=list[js.ExperimentOut]) def get_experiments( *, storage: Storage = DEFAULT_STORAGE, -) -> List[js.ExperimentOut]: +) -> list[js.ExperimentOut]: return [ js.ExperimentOut( id=experiment.id, @@ -48,13 +47,13 @@ def get_experiment_by_id( @router.get( - "/experiments/{experiment_id}/ensembles", response_model=List[js.EnsembleOut] + "/experiments/{experiment_id}/ensembles", response_model=list[js.EnsembleOut] ) def get_experiment_ensembles( *, storage: Storage = DEFAULT_STORAGE, experiment_id: UUID, -) -> List[js.EnsembleOut]: +) -> list[js.EnsembleOut]: return [ js.EnsembleOut( id=ensemble.id, diff --git a/src/ert/dark_storage/endpoints/observations.py b/src/ert/dark_storage/endpoints/observations.py index 7e98ba3b310..172d9d59987 100644 --- a/src/ert/dark_storage/endpoints/observations.py +++ b/src/ert/dark_storage/endpoints/observations.py @@ -1,4 +1,3 @@ -from typing import List from uuid import UUID from fastapi import APIRouter, Body, Depends @@ -15,11 +14,11 @@ @router.get( - "/experiments/{experiment_id}/observations", response_model=List[js.ObservationOut] + "/experiments/{experiment_id}/observations", response_model=list[js.ObservationOut] ) def get_observations( *, storage: Storage = DEFAULT_STORAGE, experiment_id: UUID -) -> List[js.ObservationOut]: +) -> list[js.ObservationOut]: experiment = storage.get_experiment(experiment_id) return [ js.ObservationOut( diff --git a/src/ert/dark_storage/endpoints/records.py b/src/ert/dark_storage/endpoints/records.py index bd8a347a6a1..923968a60e2 100644 --- a/src/ert/dark_storage/endpoints/records.py +++ b/src/ert/dark_storage/endpoints/records.py @@ -1,5 +1,6 @@ import io -from typing import Annotated, Any, Mapping, Union +from collections.abc import Mapping +from typing import Annotated, Any from urllib.parse import unquote from uuid import UUID, uuid4 @@ -78,7 +79,7 @@ async def get_ensemble_record( storage: Storage = DEFAULT_STORAGE, name: str, ensemble_id: UUID, - accept: Annotated[Union[str, None], Header()] = None, + accept: Annotated[str | None, Header()] = None, ) -> Any: name = unquote(name) try: diff --git a/src/ert/dark_storage/enkf.py b/src/ert/dark_storage/enkf.py index 27de2ca4541..58c9c6b4bab 100644 --- a/src/ert/dark_storage/enkf.py +++ b/src/ert/dark_storage/enkf.py @@ -1,7 +1,6 @@ from __future__ import annotations import os -from typing import Optional from fastapi import Depends @@ -12,7 +11,7 @@ __all__ = ["get_storage"] -_storage: Optional[Storage] = None +_storage: Storage | None = None DEFAULT_SECURITY = Depends(security) diff --git a/src/ert/dark_storage/json_schema/ensemble.py b/src/ert/dark_storage/json_schema/ensemble.py index 891bc848c17..8e9a6f91e58 100644 --- a/src/ert/dark_storage/json_schema/ensemble.py +++ b/src/ert/dark_storage/json_schema/ensemble.py @@ -1,4 +1,5 @@ -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any from uuid import UUID from pydantic import BaseModel diff --git a/src/ert/dark_storage/json_schema/experiment.py b/src/ert/dark_storage/json_schema/experiment.py index ead15446373..e6deee8801b 100644 --- a/src/ert/dark_storage/json_schema/experiment.py +++ b/src/ert/dark_storage/json_schema/experiment.py @@ -1,4 +1,5 @@ -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any from uuid import UUID from pydantic import ConfigDict, Field diff --git a/src/ert/dark_storage/json_schema/observation.py b/src/ert/dark_storage/json_schema/observation.py index 2b829a32167..f14ac673fb9 100644 --- a/src/ert/dark_storage/json_schema/observation.py +++ b/src/ert/dark_storage/json_schema/observation.py @@ -1,4 +1,5 @@ -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any from uuid import UUID, uuid4 from pydantic import ConfigDict, Field diff --git a/src/ert/dark_storage/json_schema/prior.py b/src/ert/dark_storage/json_schema/prior.py index 3a240e9ab3a..91cc0707a44 100644 --- a/src/ert/dark_storage/json_schema/prior.py +++ b/src/ert/dark_storage/json_schema/prior.py @@ -1,4 +1,4 @@ -from typing import Literal, Union +from typing import Literal from pydantic import BaseModel @@ -131,16 +131,16 @@ class PriorErtDErf(BaseModel): width: float -Prior = Union[ - PriorConst, - PriorTrig, - PriorNormal, - PriorLogNormal, - PriorErtTruncNormal, - PriorStdNormal, - PriorUniform, - PriorErtDUniform, - PriorLogUniform, - PriorErtErf, - PriorErtDErf, -] +Prior = ( + PriorConst + | PriorTrig + | PriorNormal + | PriorLogNormal + | PriorErtTruncNormal + | PriorStdNormal + | PriorUniform + | PriorErtDUniform + | PriorLogUniform + | PriorErtErf + | PriorErtDErf +) diff --git a/src/ert/dark_storage/json_schema/record.py b/src/ert/dark_storage/json_schema/record.py index 2ce03df2498..7daf360be8e 100644 --- a/src/ert/dark_storage/json_schema/record.py +++ b/src/ert/dark_storage/json_schema/record.py @@ -1,4 +1,5 @@ -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any from uuid import UUID from pydantic import ConfigDict @@ -15,4 +16,4 @@ class RecordOut(_Record): id: UUID name: str userdata: Mapping[str, Any] - has_observations: Optional[bool] + has_observations: bool | None diff --git a/src/ert/dark_storage/json_schema/update.py b/src/ert/dark_storage/json_schema/update.py index 3bb7f6094a5..b580f63e8a9 100644 --- a/src/ert/dark_storage/json_schema/update.py +++ b/src/ert/dark_storage/json_schema/update.py @@ -3,9 +3,7 @@ from pydantic import ConfigDict from pydantic.dataclasses import dataclass -from .observation import ( - ObservationTransformationIn, -) +from .observation import ObservationTransformationIn @dataclass diff --git a/src/ert/dark_storage/security.py b/src/ert/dark_storage/security.py index f410b8047d9..6d81d7fbb38 100644 --- a/src/ert/dark_storage/security.py +++ b/src/ert/dark_storage/security.py @@ -1,5 +1,4 @@ import os -from typing import Optional from fastapi import HTTPException, Security, status from fastapi.security import APIKeyHeader @@ -8,7 +7,7 @@ _security_header = APIKeyHeader(name="Token", auto_error=False) -async def security(*, token: Optional[str] = Security(_security_header)) -> None: +async def security(*, token: str | None = Security(_security_header)) -> None: if os.getenv("ERT_STORAGE_NO_TOKEN"): return if not token: diff --git a/src/ert/data/_measured_data.py b/src/ert/data/_measured_data.py index beec3282446..0fa7444b532 100644 --- a/src/ert/data/_measured_data.py +++ b/src/ert/data/_measured_data.py @@ -8,7 +8,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING import numpy as np import pandas as pd @@ -26,7 +26,7 @@ class MeasuredData: def __init__( self, ensemble: Ensemble, - keys: Optional[List[str]] = None, + keys: list[str] | None = None, ): if keys is None: keys = sorted(ensemble.experiment.observation_keys) @@ -83,7 +83,7 @@ def is_empty(self) -> bool: @staticmethod def _get_data( ensemble: Ensemble, - observation_keys: List[str], + observation_keys: list[str], ) -> pd.DataFrame: """ Adds simulated and observed data and returns a dataframe where ensemble diff --git a/src/ert/enkf_main.py b/src/ert/enkf_main.py index c6511829daa..4b83db6cb7e 100644 --- a/src/ert/enkf_main.py +++ b/src/ert/enkf_main.py @@ -4,9 +4,10 @@ import logging import os import time +from collections.abc import Iterable, Mapping from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping +from typing import TYPE_CHECKING, Any import orjson from numpy.random import SeedSequence @@ -65,7 +66,7 @@ def _value_export_json( return # Hierarchical - json_out: Dict[str, float | Dict[str, float]] = { + json_out: dict[str, float | dict[str, float]] = { key: dict(param_map.items()) for key, param_map in values.items() } @@ -99,7 +100,7 @@ def _generate_parameter_files( iens: Realisation index fs: Ensemble from which to load parameter data """ - exports: Dict[str, Dict[str, float]] = {} + exports: dict[str, dict[str, float]] = {} for node in parameter_configs: # For the first iteration we do not write the parameter @@ -116,13 +117,13 @@ def _generate_parameter_files( _value_export_json(run_path, export_base_name, exports) -def _manifest_to_json(ensemble: Ensemble, iens: int, iter: int) -> Dict[str, Any]: +def _manifest_to_json(ensemble: Ensemble, iens: int, iter: int) -> dict[str, Any]: manifest = {} # Add expected parameter files to manifest for param_config in ensemble.experiment.parameter_configuration.values(): assert isinstance( param_config, - (ExtParamConfig, GenKwConfig, Field, SurfaceConfig), + ExtParamConfig | GenKwConfig | Field | SurfaceConfig, ) if param_config.forward_init and ensemble.iteration == 0: assert param_config.forward_init_file is not None @@ -195,14 +196,14 @@ def create_run_path( run_args: list[RunArg], ensemble: Ensemble, user_config_file: str, - env_vars: Dict[str, str], - env_pr_fm_step: Dict[str, Dict[str, Any]], + env_vars: dict[str, str], + env_pr_fm_step: dict[str, dict[str, Any]], forward_model_steps: list[ForwardModelStep], substitutions: Substitutions, templates: list[tuple[str, str]], model_config: ModelConfig, runpaths: Runpaths, - context_env: Dict[str, str] | None = None, + context_env: dict[str, str] | None = None, ) -> None: if context_env is None: context_env = {} diff --git a/src/ert/ensemble_evaluator/_ensemble.py b/src/ert/ensemble_evaluator/_ensemble.py index ed079ae9ea8..da9af3aec73 100644 --- a/src/ert/ensemble_evaluator/_ensemble.py +++ b/src/ert/ensemble_evaluator/_ensemble.py @@ -3,14 +3,12 @@ import asyncio import logging import traceback +from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass from functools import partialmethod from typing import ( Any, - Awaitable, - Callable, Protocol, - Sequence, ) from _ert.events import ( @@ -185,7 +183,7 @@ def update_snapshot(self, events: Sequence[Event]) -> EnsembleSnapshot: self.status = self._status_tracker.update_state(self.snapshot.status) for event in events: - if isinstance(event, (ForwardModelStepSuccess, ForwardModelStepFailure)): + if isinstance(event, ForwardModelStepSuccess | ForwardModelStepFailure): step = ( self.snapshot.reals[event.real] .get("fm_steps", {}) @@ -300,7 +298,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches self._scheduler.add_dispatch_information_to_jobs_file() result = await self._scheduler.execute(min_required_realizations) except PermissionError as error: - logger.exception((f"Unexpected exception in ensemble: \n {error!s}")) + logger.exception(f"Unexpected exception in ensemble: \n {error!s}") await event_unary_send(event_creator(Id.ENSEMBLE_FAILED)) return except Exception as exc: @@ -340,7 +338,7 @@ class Realization: fm_steps: Sequence[ForwardModelStep] active: bool max_runtime: int | None - run_arg: "RunArg" + run_arg: RunArg num_cpu: int job_script: str realization_memory: int # Memory to reserve/book, in bytes diff --git a/src/ert/ensemble_evaluator/_wait_for_evaluator.py b/src/ert/ensemble_evaluator/_wait_for_evaluator.py index 4cbde527e0d..f97fb758a6b 100644 --- a/src/ert/ensemble_evaluator/_wait_for_evaluator.py +++ b/src/ert/ensemble_evaluator/_wait_for_evaluator.py @@ -2,7 +2,6 @@ import logging import ssl import time -from typing import Optional, Union import aiohttp @@ -11,7 +10,7 @@ WAIT_FOR_EVALUATOR_TIMEOUT = 60 -def get_ssl_context(cert: Optional[Union[str, bytes]]) -> Union[ssl.SSLContext, bool]: +def get_ssl_context(cert: str | bytes | None) -> ssl.SSLContext | bool: if cert is None: return False ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) @@ -21,8 +20,8 @@ def get_ssl_context(cert: Optional[Union[str, bytes]]) -> Union[ssl.SSLContext, async def attempt_connection( url: str, - token: Optional[str] = None, - cert: Optional[Union[str, bytes]] = None, + token: str | None = None, + cert: str | bytes | None = None, connection_timeout: float = 2, ) -> None: timeout = aiohttp.ClientTimeout(connect=connection_timeout) @@ -42,10 +41,10 @@ async def attempt_connection( async def wait_for_evaluator( base_url: str, - token: Optional[str] = None, - cert: Optional[Union[str, bytes]] = None, + token: str | None = None, + cert: str | bytes | None = None, healthcheck_endpoint: str = "/healthcheck", - timeout: Optional[float] = None, # noqa: ASYNC109 + timeout: float | None = None, # noqa: ASYNC109 connection_timeout: float = 2, ) -> None: if timeout is None: diff --git a/src/ert/ensemble_evaluator/config.py b/src/ert/ensemble_evaluator/config.py index 536ed761c73..51b059a6ce1 100644 --- a/src/ert/ensemble_evaluator/config.py +++ b/src/ert/ensemble_evaluator/config.py @@ -5,11 +5,9 @@ import socket import ssl import tempfile -import typing import warnings from base64 import b64encode -from datetime import datetime, timedelta, timezone -from typing import Optional +from datetime import UTC, datetime, timedelta from cryptography import x509 from cryptography.hazmat.backends import default_backend @@ -43,7 +41,7 @@ def _generate_authentication() -> str: def _generate_certificate( ip_address: str, -) -> typing.Tuple[str, bytes, bytes]: +) -> tuple[str, bytes, bytes]: """Generate a private key and a certificate signed with it The key is encrypted before being stored. Returns the certificate as a string, the key as bytes (encrypted), and @@ -71,8 +69,8 @@ def _generate_certificate( .issuer_name(issuer) .public_key(key.public_key()) .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.now(timezone.utc)) - .not_valid_after(datetime.now(timezone.utc) + timedelta(days=365)) # 1 year + .not_valid_before(datetime.now(UTC)) + .not_valid_after(datetime.now(UTC) + timedelta(days=365)) # 1 year .add_extension( x509.SubjectAlternativeName( [ @@ -123,10 +121,10 @@ class EvaluatorServerConfig: def __init__( self, - custom_port_range: typing.Optional[range] = None, + custom_port_range: range | None = None, use_token: bool = True, generate_cert: bool = True, - custom_host: typing.Optional[str] = None, + custom_host: str | None = None, ) -> None: self._socket_handle = find_available_socket( custom_range=custom_port_range, custom_host=custom_host @@ -141,7 +139,7 @@ def __init__( else: cert, key, pw = None, None, None self.cert = cert - self._key: Optional[bytes] = key + self._key: bytes | None = key self._key_pw = pw self.token = _generate_authentication() if use_token else None @@ -158,7 +156,7 @@ def get_connection_info(self) -> EvaluatorConnectionInfo: def get_server_ssl_context( self, protocol: int = ssl.PROTOCOL_TLS_SERVER - ) -> typing.Optional[ssl.SSLContext]: + ) -> ssl.SSLContext | None: if self.cert is None: return None with tempfile.TemporaryDirectory() as tmp_dir: diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 7f89c36dd7f..3a6b92a77a6 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -2,17 +2,18 @@ import datetime import logging import traceback -from contextlib import asynccontextmanager, contextmanager -from http import HTTPStatus -from typing import ( - Any, +from collections.abc import ( AsyncIterator, Awaitable, Callable, Generator, Iterable, Sequence, - Type, +) +from contextlib import asynccontextmanager, contextmanager +from http import HTTPStatus +from typing import ( + Any, get_args, ) @@ -112,9 +113,9 @@ async def _process_event_buffer(self) -> None: self._batch_processing_queue.task_done() async def _batch_events_into_buffer(self) -> None: - event_handler: dict[Type[Event], EVENT_HANDLER] = {} + event_handler: dict[type[Event], EVENT_HANDLER] = {} - def set_event_handler(event_types: set[Type[Event]], func: Any) -> None: + def set_event_handler(event_types: set[type[Event]], func: Any) -> None: for event_type in event_types: event_handler[event_type] = func @@ -138,7 +139,7 @@ def set_event_handler(event_types: set[Type[Event]], func: Any) -> None: function = event_handler[type(event)] batch.append((function, event)) self._events.task_done() - except asyncio.TimeoutError: + except TimeoutError: continue self._complete_batch.set() await self._batch_processing_queue.put(batch) @@ -315,7 +316,7 @@ async def _server(self) -> None: await asyncio.wait_for( self._dispatchers_connected.join(), timeout=20 ) - except asyncio.TimeoutError: + except TimeoutError: logger.debug("Timed out waiting for dispatchers to disconnect") else: logger.debug("Got done signal. No dispatchers connected") @@ -414,10 +415,8 @@ def log_exception(task_exception: BaseException, task_name: str) -> None: ) ) logger.error( - ( - f"Exception in evaluator task {task_name}: {task_exception}\n" - f"Traceback: {exc_traceback}" - ) + f"Exception in evaluator task {task_name}: {task_exception}\n" + f"Traceback: {exc_traceback}" ) async def run_and_get_successful_realizations(self) -> list[int]: diff --git a/src/ert/ensemble_evaluator/evaluator_connection_info.py b/src/ert/ensemble_evaluator/evaluator_connection_info.py index bd48e08e4a1..e01326c5c99 100644 --- a/src/ert/ensemble_evaluator/evaluator_connection_info.py +++ b/src/ert/ensemble_evaluator/evaluator_connection_info.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional, Union @dataclass @@ -7,8 +6,8 @@ class EvaluatorConnectionInfo: """Read only server-info""" url: str - cert: Optional[Union[str, bytes]] = None - token: Optional[str] = None + cert: str | bytes | None = None + token: str | None = None @property def dispatch_uri(self) -> str: diff --git a/src/ert/ensemble_evaluator/event.py b/src/ert/ensemble_evaluator/event.py index eaeecc99a0b..856557968f3 100644 --- a/src/ert/ensemble_evaluator/event.py +++ b/src/ert/ensemble_evaluator/event.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Dict, Optional from .snapshot import EnsembleSnapshot @@ -11,21 +10,21 @@ class _UpdateEvent: total_iterations: int progress: float realization_count: int - status_count: Dict[str, int] + status_count: dict[str, int] iteration: int @dataclass class FullSnapshotEvent(_UpdateEvent): - snapshot: Optional[EnsembleSnapshot] = None + snapshot: EnsembleSnapshot | None = None @dataclass class SnapshotUpdateEvent(_UpdateEvent): - snapshot: Optional[EnsembleSnapshot] = None + snapshot: EnsembleSnapshot | None = None @dataclass class EndEvent: failed: bool - msg: Optional[str] = None + msg: str | None = None diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py index ab2ddb4c495..d3f549377c6 100644 --- a/src/ert/ensemble_evaluator/monitor.py +++ b/src/ert/ensemble_evaluator/monitor.py @@ -2,7 +2,8 @@ import logging import ssl import uuid -from typing import TYPE_CHECKING, Any, AsyncGenerator, Final, Optional, Union +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any, Final from aiohttp import ClientError from websockets import ConnectionClosed, Headers @@ -35,9 +36,9 @@ class Monitor: def __init__(self, ee_con_info: "EvaluatorConnectionInfo") -> None: self._ee_con_info = ee_con_info self._id = str(uuid.uuid1()).split("-", maxsplit=1)[0] - self._event_queue: asyncio.Queue[Union[Event, EventSentinel]] = asyncio.Queue() - self._connection: Optional[ClientConnection] = None - self._receiver_task: Optional[asyncio.Task[None]] = None + self._event_queue: asyncio.Queue[Event | EventSentinel] = asyncio.Queue() + self._connection: ClientConnection | None = None + self._receiver_task: asyncio.Task[None] | None = None self._connected: asyncio.Future[None] = asyncio.Future() self._connection_timeout: float = 120.0 self._receiver_timeout: float = 60.0 @@ -46,7 +47,7 @@ async def __aenter__(self) -> "Monitor": self._receiver_task = asyncio.create_task(self._receiver()) try: await asyncio.wait_for(self._connected, timeout=self._connection_timeout) - except asyncio.TimeoutError as exc: + except TimeoutError as exc: msg = "Couldn't establish connection with the ensemble evaluator!" logger.error(msg) self._receiver_task.cancel() @@ -86,28 +87,28 @@ async def signal_done(self) -> None: logger.debug(f"monitor-{self._id} informed server monitor is done") async def track( - self, heartbeat_interval: Optional[float] = None - ) -> AsyncGenerator[Optional[Event], None]: + self, heartbeat_interval: float | None = None + ) -> AsyncGenerator[Event | None, None]: """Yield events from the internal event queue with optional heartbeats. Heartbeats are represented by None being yielded. Heartbeats stops being emitted after a CloseTrackerEvent is found.""" - _heartbeat_interval: Optional[float] = heartbeat_interval + heartbeat_interval_: float | None = heartbeat_interval closetracker_received: bool = False while True: try: event = await asyncio.wait_for( - self._event_queue.get(), timeout=_heartbeat_interval + self._event_queue.get(), timeout=heartbeat_interval_ ) - except asyncio.TimeoutError: + except TimeoutError: if closetracker_received: logger.error("Evaluator did not send the TERMINATED event!") break event = None if isinstance(event, EventSentinel): closetracker_received = True - _heartbeat_interval = self._receiver_timeout + heartbeat_interval_ = self._receiver_timeout else: yield event if type(event) is EETerminated: @@ -117,7 +118,7 @@ async def track( self._event_queue.task_done() async def _receiver(self) -> None: - tls: Optional[ssl.SSLContext] = None + tls: ssl.SSLContext | None = None if self._ee_con_info.cert: tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) tls.load_verify_locations(cadata=self._ee_con_info.cert) diff --git a/src/ert/ensemble_evaluator/snapshot.py b/src/ert/ensemble_evaluator/snapshot.py index 6a6bd964660..1b4c5ebb6de 100644 --- a/src/ert/ensemble_evaluator/snapshot.py +++ b/src/ert/ensemble_evaluator/snapshot.py @@ -1,9 +1,10 @@ from __future__ import annotations import logging -from collections import defaultdict +from collections import Counter, defaultdict +from collections.abc import Mapping from datetime import datetime -from typing import Any, Counter, Mapping, TypeVar, cast, get_args +from typing import Any, TypeVar, cast, get_args from qtpy.QtGui import QColor from typing_extensions import TypedDict @@ -113,7 +114,7 @@ def __init__(self) -> None: ) @classmethod - def from_nested_dict(cls, source: Mapping[Any, Any]) -> "EnsembleSnapshot": + def from_nested_dict(cls, source: Mapping[Any, Any]) -> EnsembleSnapshot: ensemble = EnsembleSnapshot() if "metadata" in source: ensemble._metadata = source["metadata"] @@ -126,7 +127,7 @@ def from_nested_dict(cls, source: Mapping[Any, Any]) -> "EnsembleSnapshot": return ensemble def add_realization( - self, real_id: RealId, realization: "RealizationSnapshot" + self, real_id: RealId, realization: RealizationSnapshot ) -> None: self._realization_snapshots[real_id] = realization @@ -134,7 +135,7 @@ def add_realization( fm_step_idx = (real_id, fm_step_id) self._fm_step_snapshots[fm_step_idx] = fm_step_snapshot - def merge_snapshot(self, ensemble: "EnsembleSnapshot") -> "EnsembleSnapshot": + def merge_snapshot(self, ensemble: EnsembleSnapshot) -> EnsembleSnapshot: self._metadata.update(ensemble._metadata) if ensemble._ensemble_state is not None: self._ensemble_state = ensemble._ensemble_state @@ -149,25 +150,25 @@ def merge_metadata(self, metadata: EnsembleSnapshotMetadata) -> None: def to_dict(self) -> dict[str, Any]: """used to send snapshot updates""" - _dict: dict[str, Any] = {} + dict_: dict[str, Any] = {} if self._metadata: - _dict["metadata"] = self._metadata + dict_["metadata"] = self._metadata if self._ensemble_state: - _dict["status"] = self._ensemble_state + dict_["status"] = self._ensemble_state if self._realization_snapshots: - _dict["reals"] = self._realization_snapshots + dict_["reals"] = self._realization_snapshots for (real_id, fm_id), fm_values_dict in self._fm_step_snapshots.items(): - if "reals" not in _dict: - _dict["reals"] = {} - if real_id not in _dict["reals"]: - _dict["reals"][real_id] = RealizationSnapshot(fm_steps={}) - if "fm_steps" not in _dict["reals"][real_id]: - _dict["reals"][real_id]["fm_steps"] = {} + if "reals" not in dict_: + dict_["reals"] = {} + if real_id not in dict_["reals"]: + dict_["reals"][real_id] = RealizationSnapshot(fm_steps={}) + if "fm_steps" not in dict_["reals"][real_id]: + dict_["reals"][real_id]["fm_steps"] = {} - _dict["reals"][real_id]["fm_steps"][fm_id] = fm_values_dict + dict_["reals"][real_id]["fm_steps"][fm_id] = fm_values_dict - return _dict + return dict_ @property def status(self) -> str | None: @@ -179,7 +180,7 @@ def metadata(self) -> EnsembleSnapshotMetadata: def get_all_fm_steps( self, - ) -> Mapping[tuple[RealId, FmStepId], "FMStepSnapshot"]: + ) -> Mapping[tuple[RealId, FmStepId], FMStepSnapshot]: return self._fm_step_snapshots.copy() def get_fm_steps_for_all_reals( @@ -192,22 +193,20 @@ def get_fm_steps_for_all_reals( } @property - def reals(self) -> Mapping[RealId, "RealizationSnapshot"]: + def reals(self) -> Mapping[RealId, RealizationSnapshot]: return self._realization_snapshots - def get_fm_steps_for_real( - self, real_id: RealId - ) -> dict[FmStepId, "FMStepSnapshot"]: + def get_fm_steps_for_real(self, real_id: RealId) -> dict[FmStepId, FMStepSnapshot]: return { fm_step_idx[1]: fm_step_snapshot.copy() for fm_step_idx, fm_step_snapshot in self._fm_step_snapshots.items() if fm_step_idx[0] == real_id } - def get_real(self, real_id: RealId) -> "RealizationSnapshot": + def get_real(self, real_id: RealId) -> RealizationSnapshot: return self._realization_snapshots[real_id] - def get_fm_step(self, real_id: RealId, fm_step_id: FmStepId) -> "FMStepSnapshot": + def get_fm_step(self, real_id: RealId, fm_step_id: FmStepId) -> FMStepSnapshot: return self._fm_step_snapshots[real_id, fm_step_id].copy() def get_successful_realizations(self) -> list[int]: @@ -219,11 +218,9 @@ def get_successful_realizations(self) -> list[int]: def aggregate_real_states(self) -> Counter[str]: counter = Counter( - ( - real["status"] - for real in self._realization_snapshots.values() - if real.get("status") is not None - ) + real["status"] + for real in self._realization_snapshots.values() + if real.get("status") is not None ) return counter # type: ignore @@ -239,7 +236,7 @@ def update_realization( end_time: datetime | None = None, exec_hosts: str | None = None, message: str | None = None, - ) -> "EnsembleSnapshot": + ) -> EnsembleSnapshot: self._realization_snapshots[real_id].update( _filter_nones( RealizationSnapshot( @@ -255,7 +252,7 @@ def update_realization( def update_from_event( self, event: Event, source_snapshot: EnsembleSnapshot | None = None - ) -> "EnsembleSnapshot": + ) -> EnsembleSnapshot: e_type = type(event) timestamp = event.time @@ -362,8 +359,8 @@ def update_fm_step( self, real_id: str, fm_step_id: str, - fm_step: "FMStepSnapshot", - ) -> "EnsembleSnapshot": + fm_step: FMStepSnapshot, + ) -> EnsembleSnapshot: self._fm_step_snapshots[real_id, fm_step_id].update(fm_step) return self diff --git a/src/ert/field_utils/grdecl_io.py b/src/ert/field_utils/grdecl_io.py index cf57b69f255..9620c964a0c 100644 --- a/src/ert/field_utils/grdecl_io.py +++ b/src/ert/field_utils/grdecl_io.py @@ -1,8 +1,9 @@ from __future__ import annotations import os +from collections.abc import Iterator from contextlib import contextmanager -from typing import Any, Iterator, TextIO +from typing import Any, TextIO import numpy as np import numpy.typing as npt @@ -146,7 +147,7 @@ def read_grdecl(grdecl_stream: TextIO) -> Iterator[tuple[str, list[str]]]: if keyword is not None: raise ValueError(f"Reached end of stream while reading {keyword}") - with open(grdecl_file, "r", encoding="utf-8") as stream: + with open(grdecl_file, encoding="utf-8") as stream: yield read_grdecl(stream) diff --git a/src/ert/field_utils/roff_io.py b/src/ert/field_utils/roff_io.py index a92aececa78..514a9395ce1 100644 --- a/src/ert/field_utils/roff_io.py +++ b/src/ert/field_utils/roff_io.py @@ -2,7 +2,7 @@ import warnings from collections import OrderedDict -from typing import TYPE_CHECKING, Any, BinaryIO, Optional, TextIO, Tuple, Union +from typing import TYPE_CHECKING, Any, BinaryIO, TextIO import numpy as np import roffio # type: ignore @@ -10,7 +10,7 @@ if TYPE_CHECKING: from os import PathLike - _PathLike = Union[str, PathLike[str]] + _PathLike = str | PathLike[str] RMS_UNDEFINED_FLOAT = np.float32(-999.0) @@ -18,7 +18,7 @@ def export_roff( data: np.ma.MaskedArray[Any, np.dtype[np.float32]], - filelike: Union[TextIO, BinaryIO, _PathLike], + filelike: TextIO | BinaryIO | _PathLike, parameter_name: str, binary: bool, ) -> None: @@ -49,7 +49,7 @@ def export_roff( def import_roff( - filelike: Union[TextIO, BinaryIO, _PathLike], name: Optional[str] = None + filelike: TextIO | BinaryIO | _PathLike, name: str | None = None ) -> np.ma.MaskedArray[Any, np.dtype[np.float32]]: looking_for = { "dimensions": { @@ -69,7 +69,7 @@ def reset_parameter() -> None: def all_set() -> bool: return all(val is not None for v in looking_for.values() for val in v.values()) - def should_skip_parameter(key: Tuple[str, str]) -> bool: + def should_skip_parameter(key: tuple[str, str]) -> bool: return key[0] == "name" and name is not None and key[1] != name with roffio.lazy_read(filelike) as tag_generator: diff --git a/src/ert/gui/about_dialog.py b/src/ert/gui/about_dialog.py index e851ab23d6d..87b96959dc9 100644 --- a/src/ert/gui/about_dialog.py +++ b/src/ert/gui/about_dialog.py @@ -1,5 +1,3 @@ -from typing import Optional - from qtpy.QtCore import QSize, Qt from qtpy.QtGui import QFont from qtpy.QtWidgets import ( @@ -15,7 +13,7 @@ class AboutDialog(QDialog): - def __init__(self, parent: Optional[QWidget]) -> None: + def __init__(self, parent: QWidget | None) -> None: QDialog.__init__(self, parent) self.setWindowTitle("About") diff --git a/src/ert/gui/ertnotifier.py b/src/ert/gui/ertnotifier.py index 6c925f353bc..274674c8329 100644 --- a/src/ert/gui/ertnotifier.py +++ b/src/ert/gui/ertnotifier.py @@ -1,5 +1,3 @@ -from typing import Optional - from qtpy.QtCore import QObject, Signal, Slot from ert.storage import Ensemble, Storage @@ -13,8 +11,8 @@ class ErtNotifier(QObject): def __init__(self, config_file: str): QObject.__init__(self) self._config_file = config_file - self._storage: Optional[Storage] = None - self._current_ensemble: Optional[Ensemble] = None + self._storage: Storage | None = None + self._current_ensemble: Ensemble | None = None self._is_simulation_running = False @property @@ -31,7 +29,7 @@ def config_file(self) -> str: return self._config_file @property - def current_ensemble(self) -> Optional[Ensemble]: + def current_ensemble(self) -> Ensemble | None: if self._current_ensemble is None and self._storage is not None: ensembles = list(self._storage.ensembles) if ensembles: @@ -58,7 +56,7 @@ def set_storage(self, storage: Storage) -> None: self.storage_changed.emit(storage) @Slot(object) - def set_current_ensemble(self, ensemble: Optional[Ensemble] = None) -> None: + def set_current_ensemble(self, ensemble: Ensemble | None = None) -> None: self._current_ensemble = ensemble self.current_ensemble_changed.emit(ensemble) diff --git a/src/ert/gui/ertwidgets/__init__.py b/src/ert/gui/ertwidgets/__init__.py index 00253078b04..cc1bf211dbb 100644 --- a/src/ert/gui/ertwidgets/__init__.py +++ b/src/ert/gui/ertwidgets/__init__.py @@ -2,7 +2,8 @@ from qtpy.QtCore import Qt from qtpy.QtGui import QCursor from qtpy.QtWidgets import QApplication -from typing import Callable, Any +from typing import Any +from collections.abc import Callable def showWaitCursorWhileWaiting(func: Callable[..., Any]) -> Callable[..., Any]: diff --git a/src/ert/gui/ertwidgets/analysismodulevariablespanel.py b/src/ert/gui/ertwidgets/analysismodulevariablespanel.py index ded4b1ebba3..0bd78746f0c 100644 --- a/src/ert/gui/ertwidgets/analysismodulevariablespanel.py +++ b/src/ert/gui/ertwidgets/analysismodulevariablespanel.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import partial -from typing import Type, cast, get_args +from typing import cast, get_args from annotated_types import Ge, Gt, Le from qtpy.QtCore import Qt @@ -187,7 +187,7 @@ def createDoubleSpinBox( def valueChanged( self, variable_name: str, - variable_type: Type[bool] | Type[float], + variable_type: type[bool] | type[float], variable_control: QWidget, ) -> None: value: bool | float | None = None diff --git a/src/ert/gui/ertwidgets/checklist.py b/src/ert/gui/ertwidgets/checklist.py index 5c01005d24b..4bf1ecdac04 100644 --- a/src/ert/gui/ertwidgets/checklist.py +++ b/src/ert/gui/ertwidgets/checklist.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from qtpy.QtCore import QPoint, QSize, Qt from qtpy.QtGui import QIcon @@ -27,7 +27,7 @@ def __init__( self, model: SelectableListModel, label: str = "", - custom_filter_button: Optional[QToolButton] = None, + custom_filter_button: QToolButton | None = None, ): """ :param custom_filter_button: if needed, add a button that opens a @@ -123,15 +123,15 @@ def modelChanged(self) -> None: self.filterList(self._search_box.filter()) - def filterList(self, _filter: str) -> None: - _filter = _filter.lower() + def filterList(self, filter_: str) -> None: + filter_ = filter_.lower() for index in range(0, self._list.count()): item = self._list.item(index) assert item is not None text = item.text().lower() - if not _filter or _filter in text: + if not filter_ or filter_ in text: item.setHidden(False) else: item.setHidden(True) diff --git a/src/ert/gui/ertwidgets/closabledialog.py b/src/ert/gui/ertwidgets/closabledialog.py index dde9713043f..c6a8cdb5df3 100644 --- a/src/ert/gui/ertwidgets/closabledialog.py +++ b/src/ert/gui/ertwidgets/closabledialog.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from qtpy.QtCore import Qt from qtpy.QtWidgets import QDialog, QHBoxLayout, QPushButton, QVBoxLayout, QWidget @@ -12,7 +12,7 @@ class ClosableDialog(QDialog): def __init__( - self, title: Optional[str], widget: QWidget, parent: Optional[QWidget] = None + self, title: str | None, widget: QWidget, parent: QWidget | None = None ) -> None: QDialog.__init__(self, parent) self.setWindowTitle(title) @@ -47,7 +47,7 @@ def disableCloseButton(self) -> None: def enableCloseButton(self) -> None: self.close_button.setEnabled(True) - def keyPressEvent(self, a0: Optional[QKeyEvent]) -> None: + def keyPressEvent(self, a0: QKeyEvent | None) -> None: if self.close_button.isEnabled() or a0 is None or a0.key() != Qt.Key.Key_Escape: QDialog.keyPressEvent(self, a0) diff --git a/src/ert/gui/ertwidgets/copy_button.py b/src/ert/gui/ertwidgets/copy_button.py index 0b41d836c0a..5ee953f11a2 100644 --- a/src/ert/gui/ertwidgets/copy_button.py +++ b/src/ert/gui/ertwidgets/copy_button.py @@ -2,12 +2,7 @@ from qtpy.QtCore import QTimer from qtpy.QtGui import QIcon -from qtpy.QtWidgets import ( - QApplication, - QMessageBox, - QPushButton, - QSizePolicy, -) +from qtpy.QtWidgets import QApplication, QMessageBox, QPushButton, QSizePolicy class CopyButton(QPushButton): diff --git a/src/ert/gui/ertwidgets/copyablelabel.py b/src/ert/gui/ertwidgets/copyablelabel.py index d689f196441..02d1d47dfb8 100644 --- a/src/ert/gui/ertwidgets/copyablelabel.py +++ b/src/ert/gui/ertwidgets/copyablelabel.py @@ -1,10 +1,7 @@ from os import path from qtpy.QtCore import Qt -from qtpy.QtWidgets import ( - QHBoxLayout, - QLabel, -) +from qtpy.QtWidgets import QHBoxLayout, QLabel from .copy_button import CopyButton diff --git a/src/ert/gui/ertwidgets/create_experiment_dialog.py b/src/ert/gui/ertwidgets/create_experiment_dialog.py index 0eca7d8f5fe..ce234ea6005 100644 --- a/src/ert/gui/ertwidgets/create_experiment_dialog.py +++ b/src/ert/gui/ertwidgets/create_experiment_dialog.py @@ -1,5 +1,3 @@ -from typing import Optional - from qtpy.QtCore import ( Qt, Signal, @@ -14,11 +12,7 @@ from ert.gui.ertnotifier import ErtNotifier from ert.gui.ertwidgets import StringBox, TextModel, ValueModel -from ert.validation import ( - ExperimentValidation, - IntegerArgument, - ProperNameArgument, -) +from ert.validation import ExperimentValidation, IntegerArgument, ProperNameArgument class CreateExperimentDialog(QDialog): @@ -28,7 +22,7 @@ def __init__( self, notifier: ErtNotifier, title: str = "Create new experiment", - parent: Optional[QWidget] = None, + parent: QWidget | None = None, ) -> None: QDialog.__init__(self, parent=parent) self.setModal(True) diff --git a/src/ert/gui/ertwidgets/ensembleselector.py b/src/ert/gui/ertwidgets/ensembleselector.py index 2bbbfd2b1a2..62926ddf332 100644 --- a/src/ert/gui/ertwidgets/ensembleselector.py +++ b/src/ert/gui/ertwidgets/ensembleselector.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Iterable, Optional +from collections.abc import Iterable +from typing import TYPE_CHECKING from qtpy.QtCore import Qt, Signal from qtpy.QtWidgets import QComboBox @@ -104,5 +105,5 @@ def _ensemble_list(self) -> Iterable[Ensemble]: def _on_current_index_changed(self, index: int) -> None: self.notifier.set_current_ensemble(self.itemData(index)) - def _on_global_current_ensemble_changed(self, data: Optional[Ensemble]) -> None: + def _on_global_current_ensemble_changed(self, data: Ensemble | None) -> None: self.setCurrentIndex(max(self.findData(data, Qt.ItemDataRole.UserRole), 0)) diff --git a/src/ert/gui/ertwidgets/listeditbox.py b/src/ert/gui/ertwidgets/listeditbox.py index 92aeadc1dc2..632c0056c00 100644 --- a/src/ert/gui/ertwidgets/listeditbox.py +++ b/src/ert/gui/ertwidgets/listeditbox.py @@ -1,4 +1,4 @@ -from typing import Dict, Iterable, Optional +from collections.abc import Iterable from uuid import UUID from qtpy.QtCore import QSize, Qt @@ -18,9 +18,7 @@ class AutoCompleteLineEdit(QLineEdit): # http://blog.elentok.com/2011/08/autocomplete-textbox-for-multiple.html - def __init__( - self, items: Iterable[Optional[str]], parent: Optional[QWidget] = None - ): + def __init__(self, items: Iterable[str | None], parent: QWidget | None = None): super().__init__(parent) self._separators = [",", " "] @@ -52,7 +50,7 @@ def textUnderCursor(self) -> str: i -= 1 return text_under_cursor - def keyPressEvent(self, a0: Optional[QKeyEvent]) -> None: + def keyPressEvent(self, a0: QKeyEvent | None) -> None: popup = self._completer.popup() if ( popup is not None @@ -73,7 +71,7 @@ def keyPressEvent(self, a0: Optional[QKeyEvent]) -> None: if popup is not None and len(completion_prefix) == 0: popup.hide() - def __updateCompleterPopupItems(self, completionPrefix: Optional[str]) -> None: + def __updateCompleterPopupItems(self, completionPrefix: str | None) -> None: self._completer.setCompletionPrefix(completionPrefix) popup = self._completer.popup() assert popup is not None @@ -87,7 +85,7 @@ class ListEditBox(QWidget): NO_ITEMS_SPECIFIED_MSG = "The list must contain at least one item or * (for all)." DEFAULT_MSG = "A list of comma separated ensemble names or * for all." - def __init__(self, possible_items: Dict[UUID, str]) -> None: + def __init__(self, possible_items: dict[UUID, str]) -> None: QWidget.__init__(self) self._editing = True @@ -129,7 +127,7 @@ def getListText(self) -> str: text = "".join(text.split()) return text - def getItems(self) -> Dict[UUID, str]: + def getItems(self) -> dict[UUID, str]: text = self.getListText() items = text.split(",") diff --git a/src/ert/gui/ertwidgets/message_box.py b/src/ert/gui/ertwidgets/message_box.py index b3aa6d495da..0eebe8c9f82 100644 --- a/src/ert/gui/ertwidgets/message_box.py +++ b/src/ert/gui/ertwidgets/message_box.py @@ -1,5 +1,3 @@ -from typing import Optional - from qtpy.QtWidgets import ( QDialog, QDialogButtonBox, @@ -21,9 +19,9 @@ class ErtMessageBox(QDialog): def __init__( self, - text: Optional[str], - detailed_text: Optional[str], - parent: Optional[QWidget] = None, + text: str | None, + detailed_text: str | None, + parent: QWidget | None = None, ) -> None: super().__init__(parent) self.box = QDialogButtonBox( diff --git a/src/ert/gui/ertwidgets/models/activerealizationsmodel.py b/src/ert/gui/ertwidgets/models/activerealizationsmodel.py index 418ae6916f7..79859cacc23 100644 --- a/src/ert/gui/ertwidgets/models/activerealizationsmodel.py +++ b/src/ert/gui/ertwidgets/models/activerealizationsmodel.py @@ -1,4 +1,4 @@ -from typing import Collection +from collections.abc import Collection from ert.gui.ertwidgets.models.valuemodel import ValueModel from ert.validation import ActiveRange, mask_to_rangestring diff --git a/src/ert/gui/ertwidgets/models/path_model.py b/src/ert/gui/ertwidgets/models/path_model.py index e03df7e7152..0d3db929d10 100644 --- a/src/ert/gui/ertwidgets/models/path_model.py +++ b/src/ert/gui/ertwidgets/models/path_model.py @@ -1,5 +1,3 @@ -from typing import Optional - from ert.gui.ertwidgets.models.valuemodel import ValueModel @@ -41,7 +39,7 @@ def pathMustExist(self) -> bool: def pathMustBeAbsolute(self) -> bool: return self._path_must_be_absolute - def getPath(self) -> Optional[str]: + def getPath(self) -> str | None: return self.getValue() def setPath(self, value: str) -> None: diff --git a/src/ert/gui/ertwidgets/models/selectable_list_model.py b/src/ert/gui/ertwidgets/models/selectable_list_model.py index e3c0c0e1054..0c78cf71f54 100644 --- a/src/ert/gui/ertwidgets/models/selectable_list_model.py +++ b/src/ert/gui/ertwidgets/models/selectable_list_model.py @@ -1,5 +1,3 @@ -from typing import Dict, List - from qtpy.QtCore import QObject, Signal @@ -7,12 +5,12 @@ class SelectableListModel(QObject): modelChanged = Signal() selectionChanged = Signal() - def __init__(self, items: List[str]) -> None: + def __init__(self, items: list[str]) -> None: QObject.__init__(self) - self._selection: Dict[str, bool] = {} + self._selection: dict[str, bool] = {} self._items = items - def getList(self) -> List[str]: + def getList(self) -> list[str]: return self._items def isValueSelected(self, value: str) -> bool: @@ -38,7 +36,7 @@ def selectAll(self) -> None: self.selectionChanged.emit() - def getSelectedItems(self) -> List[str]: + def getSelectedItems(self) -> list[str]: return [item for item in self.getList() if self.isValueSelected(item)] def _setSelectState(self, key: str, state: bool) -> None: diff --git a/src/ert/gui/ertwidgets/models/targetensemblemodel.py b/src/ert/gui/ertwidgets/models/targetensemblemodel.py index c12cf101563..aeb3c6a2b6e 100644 --- a/src/ert/gui/ertwidgets/models/targetensemblemodel.py +++ b/src/ert/gui/ertwidgets/models/targetensemblemodel.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from ert.config import AnalysisConfig from ert.gui.ertnotifier import ErtNotifier @@ -18,7 +18,7 @@ def __init__( notifier.ertChanged.connect(self.on_current_ensemble_changed) notifier.current_ensemble_changed.connect(self.on_current_ensemble_changed) - def setValue(self, value: Optional[str]) -> None: + def setValue(self, value: str | None) -> None: """Set a new target ensemble""" if value == self.getDefaultValue(): self._custom = False @@ -27,7 +27,7 @@ def setValue(self, value: Optional[str]) -> None: self._custom = True ValueModel.setValue(self, value) - def getDefaultValue(self) -> Optional[str]: + def getDefaultValue(self) -> str | None: ensemble_name = self.notifier.current_ensemble_name return f"{ensemble_name}_%d" diff --git a/src/ert/gui/ertwidgets/models/text_model.py b/src/ert/gui/ertwidgets/models/text_model.py index 517284b857d..4ad1f9802c4 100644 --- a/src/ert/gui/ertwidgets/models/text_model.py +++ b/src/ert/gui/ertwidgets/models/text_model.py @@ -1,5 +1,3 @@ -from typing import Optional - from ert.gui.ertwidgets.models.valuemodel import ValueModel @@ -11,7 +9,7 @@ def __init__( self.default_value = default_value super().__init__(self.getDefaultValue()) - def setValue(self, value: Optional[str]) -> None: + def setValue(self, value: str | None) -> None: if not value or not value.strip() or value == self.getDefaultValue(): ValueModel.setValue(self, self.getDefaultValue()) else: diff --git a/src/ert/gui/ertwidgets/models/valuemodel.py b/src/ert/gui/ertwidgets/models/valuemodel.py index 059866cdf41..42ea0087bb3 100644 --- a/src/ert/gui/ertwidgets/models/valuemodel.py +++ b/src/ert/gui/ertwidgets/models/valuemodel.py @@ -1,20 +1,18 @@ -from typing import Optional - from qtpy.QtCore import QObject, Signal, Slot class ValueModel(QObject): valueChanged = Signal(str) - def __init__(self, value: Optional[str] = ""): + def __init__(self, value: str | None = ""): super().__init__() self._value = value - def getValue(self) -> Optional[str]: + def getValue(self) -> str | None: return self._value @Slot(str) - def setValue(self, value: Optional[str]) -> None: + def setValue(self, value: str | None) -> None: self._value = value self.valueChanged.emit(value) diff --git a/src/ert/gui/ertwidgets/pathchooser.py b/src/ert/gui/ertwidgets/pathchooser.py index d8636cffafc..79c9d072f04 100644 --- a/src/ert/gui/ertwidgets/pathchooser.py +++ b/src/ert/gui/ertwidgets/pathchooser.py @@ -2,7 +2,7 @@ import os import re -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING from qtpy.QtCore import QSize from qtpy.QtGui import QIcon @@ -64,7 +64,7 @@ def __init__(self, model: PathModel) -> None: self.setLayout(layout) self.getPathFromModel() - def isPathValid(self, path: str) -> Tuple[bool, str]: + def isPathValid(self, path: str) -> tuple[bool, str]: path = path.strip() path_exists = os.path.exists(path) is_file = os.path.isfile(path) diff --git a/src/ert/gui/ertwidgets/searchbox.py b/src/ert/gui/ertwidgets/searchbox.py index 2bbc95fa7a1..473387257f0 100644 --- a/src/ert/gui/ertwidgets/searchbox.py +++ b/src/ert/gui/ertwidgets/searchbox.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from qtpy.QtCore import Qt, Signal from qtpy.QtGui import QColor, QFocusEvent, QKeyEvent @@ -54,15 +54,15 @@ def exitSearch(self) -> None: if not self.text(): self.presentSearch() - def focusInEvent(self, a0: Optional[QFocusEvent]) -> None: + def focusInEvent(self, a0: QFocusEvent | None) -> None: QLineEdit.focusInEvent(self, a0) self.enterSearch() - def focusOutEvent(self, a0: Optional[QFocusEvent]) -> None: + def focusOutEvent(self, a0: QFocusEvent | None) -> None: QLineEdit.focusOutEvent(self, a0) self.exitSearch() - def keyPressEvent(self, a0: Optional[QKeyEvent]) -> None: + def keyPressEvent(self, a0: QKeyEvent | None) -> None: if a0 is not None and a0.key() == Qt.Key.Key_Escape: self.clear() self.clearFocus() diff --git a/src/ert/gui/ertwidgets/stringbox.py b/src/ert/gui/ertwidgets/stringbox.py index 2efbf6d493c..60f56bfc634 100644 --- a/src/ert/gui/ertwidgets/stringbox.py +++ b/src/ert/gui/ertwidgets/stringbox.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from qtpy.QtGui import QPalette from qtpy.QtWidgets import QLineEdit @@ -34,7 +34,7 @@ def __init__( QLineEdit.__init__(self) self.setMinimumWidth(minimum_width) self._validation = ValidationSupport(self) - self._validator: Optional[ArgumentDefinition] = None + self._validator: ArgumentDefinition | None = None self._model = model self._enable_validation = True @@ -81,7 +81,7 @@ def emitChange(self, q_string: Any) -> None: def stringBoxChanged(self) -> None: """Called whenever the contents of the editline changes.""" - text: Optional[str] = self.text() + text: str | None = self.text() if not text: text = None diff --git a/src/ert/gui/ertwidgets/textbox.py b/src/ert/gui/ertwidgets/textbox.py index b8b0837fb1e..5362ae4d148 100644 --- a/src/ert/gui/ertwidgets/textbox.py +++ b/src/ert/gui/ertwidgets/textbox.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from qtpy.QtGui import QPalette from qtpy.QtWidgets import QTextEdit @@ -27,7 +27,7 @@ def __init__( QTextEdit.__init__(self) self.setMinimumWidth(minimum_width) self._validation = ValidationSupport(self) - self._validator: Optional[StringDefinition] = None + self._validator: StringDefinition | None = None self._model = model self._enable_validation = True @@ -68,7 +68,7 @@ def emitChange(self, q_string: Any) -> None: def textBoxChanged(self) -> None: """Called whenever the contents of the textbox changes.""" - text: Optional[str] = self.toPlainText() + text: str | None = self.toPlainText() if not text: text = None diff --git a/src/ert/gui/ertwidgets/validationsupport.py b/src/ert/gui/ertwidgets/validationsupport.py index b85f936f09e..8e7bdeaef03 100644 --- a/src/ert/gui/ertwidgets/validationsupport.py +++ b/src/ert/gui/ertwidgets/validationsupport.py @@ -1,7 +1,7 @@ from __future__ import annotations import html -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from qtpy.QtCore import ( QObject, @@ -10,13 +10,7 @@ Signal, ) from qtpy.QtGui import QColor -from qtpy.QtWidgets import ( - QFrame, - QLabel, - QSizePolicy, - QVBoxLayout, - QWidget, -) +from qtpy.QtWidgets import QFrame, QLabel, QSizePolicy, QVBoxLayout, QWidget if TYPE_CHECKING: from qtpy.QtCore import QEvent @@ -82,15 +76,15 @@ def __init__(self, validation_target: QWidget) -> None: QObject.__init__(self) self._validation_target = validation_target - self._validation_message: Optional[str] = None - self._validation_type: Optional[str] = None + self._validation_message: str | None = None + self._validation_type: str | None = None self._error_popup = ErrorPopup() self._originalEnterEvent = validation_target.enterEvent self._originalLeaveEvent = validation_target.leaveEvent self._originalHideEvent = validation_target.hideEvent - def enterEvent(a0: Optional[QEvent]) -> None: + def enterEvent(a0: QEvent | None) -> None: self._originalEnterEvent(a0) if not self.isValid(): @@ -101,7 +95,7 @@ def enterEvent(a0: Optional[QEvent]) -> None: validation_target.enterEvent = enterEvent # type: ignore[method-assign] - def leaveEvent(a0: Optional[QEvent]) -> None: + def leaveEvent(a0: QEvent | None) -> None: self._originalLeaveEvent(a0) if self._error_popup is not None: @@ -109,7 +103,7 @@ def leaveEvent(a0: Optional[QEvent]) -> None: validation_target.leaveEvent = leaveEvent # type: ignore[method-assign] - def hideEvent(a0: Optional[QHideEvent]) -> None: + def hideEvent(a0: QHideEvent | None) -> None: self._error_popup.hide() self._originalHideEvent(a0) diff --git a/src/ert/gui/main.py b/src/ert/gui/main.py index dbcb6ca8b18..b4d412bb84a 100755 --- a/src/ert/gui/main.py +++ b/src/ert/gui/main.py @@ -5,7 +5,6 @@ from collections import Counter from importlib.resources import files from signal import SIG_DFL, SIGINT, signal -from typing import Optional, Tuple from qtpy.QtCore import QDir from qtpy.QtGui import QIcon @@ -30,7 +29,7 @@ from .suggestor import Suggestor -def run_gui(args: Namespace, plugin_manager: Optional[ErtPluginManager] = None) -> int: +def run_gui(args: Namespace, plugin_manager: ErtPluginManager | None = None) -> int: # Replace Python's exception handler for SIGINT with the system default. # # Python's SIGINT handler is the one that raises KeyboardInterrupt. This is @@ -68,8 +67,8 @@ def show_window() -> int: def _start_initial_gui_window( args: Namespace, log_handler: GUILogHandler, - plugin_manager: Optional[ErtPluginManager] = None, -) -> Tuple[QWidget, Optional[str]]: + plugin_manager: ErtPluginManager | None = None, +) -> tuple[QWidget, str | None]: # Create logger inside function to make sure all handlers have been added to # the root-logger. logger = logging.getLogger(__name__) @@ -119,17 +118,17 @@ def _start_initial_gui_window( for msg in validation_messages.warnings: logger.info(f"Warning shown in gui '{msg}'") - _main_window = _setup_main_window( + main_window = _setup_main_window( ert_config, args, log_handler, storage, plugin_manager ) if validation_messages.warnings or validation_messages.deprecations: def continue_action() -> None: - _main_window.show() - _main_window.activateWindow() - _main_window.raise_() - _main_window.adjustSize() + main_window.show() + main_window.activateWindow() + main_window.raise_() + main_window.adjustSize() suggestor = Suggestor( validation_messages.errors, @@ -138,14 +137,14 @@ def continue_action() -> None: continue_action, plugin_manager.get_help_links() if plugin_manager is not None else {}, ) - suggestor.notifier = _main_window.notifier + suggestor.notifier = main_window.notifier return ( suggestor, ert_config.ens_path, ) else: return ( - _main_window, + main_window, ert_config.ens_path, ) @@ -155,7 +154,7 @@ def _setup_main_window( args: Namespace, log_handler: GUILogHandler, storage: Storage, - plugin_manager: Optional[ErtPluginManager] = None, + plugin_manager: ErtPluginManager | None = None, ) -> ErtMainWindow: # window reference must be kept until app.exec returns: window = ErtMainWindow(args.config, ert_config, plugin_manager, log_handler) diff --git a/src/ert/gui/main_window.py b/src/ert/gui/main_window.py index 5f8a9f046b9..497e46a364f 100644 --- a/src/ert/gui/main_window.py +++ b/src/ert/gui/main_window.py @@ -194,9 +194,7 @@ def slot_add_widget(self, run_dialog: RunDialog) -> None: widget.setVisible(False) run_dialog.setParent(self) - date_time = datetime.datetime.now(datetime.timezone.utc).strftime( - "%Y-%d-%m %H:%M:%S" - ) + date_time = datetime.datetime.now(datetime.UTC).strftime("%Y-%d-%m %H:%M:%S") experiment_type = run_dialog._run_model.name() simulation_id = experiment_type + " : " + date_time self.central_panels_map[simulation_id] = run_dialog diff --git a/src/ert/gui/model/node.py b/src/ert/gui/model/node.py index 345129ee2c2..e8baa8ee5c5 100644 --- a/src/ert/gui/model/node.py +++ b/src/ert/gui/model/node.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Optional, cast +from typing import cast from qtpy.QtGui import QColor @@ -12,11 +12,11 @@ @dataclass class _Node(ABC): id_: str - parent: Optional[RootNode | IterNode | RealNode] = None + parent: RootNode | IterNode | RealNode | None = None children: ( dict[str, IterNode] | dict[str, RealNode] | dict[str, ForwardModelStepNode] ) = field(default_factory=dict) - _index: Optional[int] = None + _index: int | None = None def __repr__(self) -> str: parent = "no " if self.parent is None else "" @@ -40,7 +40,7 @@ def row(self) -> int: class RootNode(_Node): parent: None = field(default=None, init=False) children: dict[str, IterNode] = field(default_factory=dict) - max_memory_usage: Optional[int] = None + max_memory_usage: int | None = None def add_child(self, node: _Node) -> None: node = cast(IterNode, node) @@ -50,13 +50,13 @@ def add_child(self, node: _Node) -> None: @dataclass class IterNodeData: - index: Optional[str] = None - status: Optional[str] = None + index: str | None = None + status: str | None = None @dataclass class IterNode(_Node): - parent: Optional[RootNode] = None + parent: RootNode | None = None data: IterNodeData = field(default_factory=IterNodeData) children: dict[str, RealNode] = field(default_factory=dict) @@ -68,20 +68,20 @@ def add_child(self, node: _Node) -> None: @dataclass class RealNodeData: - status: Optional[str] = None - active: Optional[bool] = False + status: str | None = None + active: bool | None = False fm_step_status_color_by_id: dict[str, QColor] = field(default_factory=dict) - real_status_color: Optional[QColor] = None - current_memory_usage: Optional[int] = None - max_memory_usage: Optional[int] = None - exec_hosts: Optional[str] = None - stderr: Optional[str] = None - message: Optional[str] = None + real_status_color: QColor | None = None + current_memory_usage: int | None = None + max_memory_usage: int | None = None + exec_hosts: str | None = None + stderr: str | None = None + message: str | None = None @dataclass class RealNode(_Node): - parent: Optional[IterNode] = None + parent: IterNode | None = None data: RealNodeData = field(default_factory=RealNodeData) children: dict[str, ForwardModelStepNode] = field(default_factory=dict) @@ -93,7 +93,7 @@ def add_child(self, node: _Node) -> None: @dataclass class ForwardModelStepNode(_Node): - parent: Optional[RealNode] + parent: RealNode | None data: FMStepSnapshot = field(default_factory=lambda: FMStepSnapshot()) # noqa: PLW0108 def add_child(self, node: _Node) -> None: diff --git a/src/ert/gui/model/snapshot.py b/src/ert/gui/model/snapshot.py index 8e8f7ae63a8..7fd96ddceb1 100644 --- a/src/ert/gui/model/snapshot.py +++ b/src/ert/gui/model/snapshot.py @@ -1,8 +1,9 @@ import logging from collections import defaultdict +from collections.abc import Sequence from contextlib import ExitStack from datetime import datetime, timedelta -from typing import Any, Final, Sequence, overload +from typing import Any, Final, overload from qtpy.QtCore import QAbstractItemModel, QModelIndex, QObject, QSize, Qt, QVariant from qtpy.QtGui import QColor, QFont @@ -419,9 +420,9 @@ def _fm_step_data( data_name = FM_STEP_COLUMNS[index.column()] if data_name in [ids.MAX_MEMORY_USAGE]: data = node.data - _bytes: str | None = data.get(data_name) # type: ignore - if _bytes: - return byte_with_unit(float(_bytes)) + bytes_: str | None = data.get(data_name) # type: ignore + if bytes_: + return byte_with_unit(float(bytes_)) if data_name in [ids.STDOUT, ids.STDERR]: if not file_has_content(index.data(FileRole)): diff --git a/src/ert/gui/simulation/combobox_with_description.py b/src/ert/gui/simulation/combobox_with_description.py index 51784a1a42e..de393bdc564 100644 --- a/src/ert/gui/simulation/combobox_with_description.py +++ b/src/ert/gui/simulation/combobox_with_description.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from qtpy.QtCore import QModelIndex, QPoint, QSize from qtpy.QtGui import QColor, QRegion @@ -26,8 +26,8 @@ def __init__( label: str, description: str, enabled: bool = True, - parent: Optional[QWidget] = None, - group: Optional[str] = None, + parent: QWidget | None = None, + group: str | None = None, ) -> None: super().__init__(parent) layout = QVBoxLayout() @@ -115,12 +115,12 @@ def sizeHint(self, option: QStyleOptionViewItem, index: QModelIndex) -> QSize: class QComboBoxWithDescription(QComboBox): - def __init__(self, parent: Optional[QWidget] = None) -> None: + def __init__(self, parent: QWidget | None = None) -> None: super().__init__(parent) self.setItemDelegate(_ComboBoxWithDescriptionDelegate(self)) def addDescriptionItem( - self, label: Optional[str], description: Any, group: Optional[str] = None + self, label: str | None, description: Any, group: str | None = None ) -> None: super().addItem(label) model = self.model() diff --git a/src/ert/gui/simulation/experiment_config_panel.py b/src/ert/gui/simulation/experiment_config_panel.py index 07cb7af316c..08e8fe2be52 100644 --- a/src/ert/gui/simulation/experiment_config_panel.py +++ b/src/ert/gui/simulation/experiment_config_panel.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Type +from typing import TYPE_CHECKING, Any from qtpy.QtCore import Signal, Slot from qtpy.QtWidgets import QWidget @@ -12,12 +12,12 @@ class ExperimentConfigPanel(QWidget): simulationConfigurationChanged = Signal() - def __init__(self, simulation_model: Type[BaseRunModel]): + def __init__(self, simulation_model: type[BaseRunModel]): QWidget.__init__(self) self.setContentsMargins(10, 10, 10, 10) self.__simulation_model = simulation_model - def get_experiment_type(self) -> Type[BaseRunModel]: + def get_experiment_type(self) -> type[BaseRunModel]: return self.__simulation_model def isConfigurationValid(self) -> bool: diff --git a/src/ert/gui/simulation/experiment_panel.py b/src/ert/gui/simulation/experiment_panel.py index 1a9e5c86510..3e4a494f872 100644 --- a/src/ert/gui/simulation/experiment_panel.py +++ b/src/ert/gui/simulation/experiment_panel.py @@ -7,7 +7,7 @@ from datetime import datetime from pathlib import Path from queue import SimpleQueue -from typing import TYPE_CHECKING, Any, Type +from typing import TYPE_CHECKING, Any from qtpy.QtCore import QSize, Qt, Signal from qtpy.QtGui import QIcon, QStandardItemModel @@ -138,7 +138,7 @@ def __init__( layout.addWidget(self._experiment_stack) - self._experiment_widgets: dict[Type[BaseRunModel], QWidget] = OrderedDict() + self._experiment_widgets: dict[type[BaseRunModel], QWidget] = OrderedDict() self.addExperimentConfigPanel( SingleTestRunPanel(run_path, notifier), True, @@ -263,17 +263,15 @@ def run_experiment(self) -> None: msg_box.setText("Run experiments") msg_box.setInformativeText( - ( - "ERT is running in an existing runpath.\n\n" - "Please be aware of the following:\n" - "- Previously generated results " - "might be overwritten.\n" - "- Previously generated files might " - "be used if not configured correctly.\n" - f"- {model.get_number_of_existing_runpaths()} out of {model.get_number_of_active_realizations()} realizations " - "are running in existing runpaths.\n" - "Are you sure you want to continue?" - ) + "ERT is running in an existing runpath.\n\n" + "Please be aware of the following:\n" + "- Previously generated results " + "might be overwritten.\n" + "- Previously generated files might " + "be used if not configured correctly.\n" + f"- {model.get_number_of_existing_runpaths()} out of {model.get_number_of_active_realizations()} realizations " + "are running in existing runpaths.\n" + "Are you sure you want to continue?" ) delete_runpath_checkbox = QCheckBox() @@ -300,7 +298,7 @@ def run_experiment(self) -> None: msg_box.setIcon(QMessageBox.Warning) msg_box.setText("ERT could not delete the existing runpath") msg_box.setInformativeText( - (f"{e}\n\n" "Continue without deleting the runpath?") + f"{e}\n\n" "Continue without deleting the runpath?" ) msg_box.setStandardButtons(QMessageBox.Yes | QMessageBox.No) msg_box.setDefaultButton(QMessageBox.No) diff --git a/src/ert/gui/simulation/multiple_data_assimilation_panel.py b/src/ert/gui/simulation/multiple_data_assimilation_panel.py index 5e550b41951..a629c77f09a 100644 --- a/src/ert/gui/simulation/multiple_data_assimilation_panel.py +++ b/src/ert/gui/simulation/multiple_data_assimilation_panel.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, List +from typing import TYPE_CHECKING, Any from qtpy.QtCore import Slot from qtpy.QtGui import QFont @@ -39,7 +39,7 @@ class Arguments: mode: str target_ensemble: str realizations: str - weights: List[float] + weights: list[float] restart_run: bool prior_ensemble_id: str # UUID not serializable in json experiment_name: str diff --git a/src/ert/gui/simulation/queue_emitter.py b/src/ert/gui/simulation/queue_emitter.py index 812c6cd8848..461e3b6dc35 100644 --- a/src/ert/gui/simulation/queue_emitter.py +++ b/src/ert/gui/simulation/queue_emitter.py @@ -4,15 +4,10 @@ from contextlib import suppress from queue import Empty, SimpleQueue from time import sleep -from typing import Optional from qtpy.QtCore import QObject, Signal, Slot -from ert.ensemble_evaluator import ( - EndEvent, - FullSnapshotEvent, - SnapshotUpdateEvent, -) +from ert.ensemble_evaluator import EndEvent, FullSnapshotEvent, SnapshotUpdateEvent from ert.gui.model.snapshot import SnapshotModel from ert.run_models import StatusEvents @@ -28,7 +23,7 @@ class QueueEmitter(QObject): def __init__( self, event_queue: SimpleQueue[StatusEvents], - parent: Optional[QObject] = None, + parent: QObject | None = None, ): super().__init__(parent) logger.debug("init QueueEmitter") @@ -51,7 +46,7 @@ def consume_and_emit(self) -> None: # pre-rendering in this thread to avoid work in main rendering thread if ( - isinstance(event, (FullSnapshotEvent, SnapshotUpdateEvent)) + isinstance(event, FullSnapshotEvent | SnapshotUpdateEvent) and event.snapshot ): SnapshotModel.prerender(event.snapshot) diff --git a/src/ert/gui/simulation/run_dialog.py b/src/ert/gui/simulation/run_dialog.py index 8958bdc741d..82e76aa0295 100644 --- a/src/ert/gui/simulation/run_dialog.py +++ b/src/ert/gui/simulation/run_dialog.py @@ -1,9 +1,9 @@ from __future__ import annotations import logging +from collections.abc import Callable from pathlib import Path from queue import SimpleQueue -from typing import Callable, Optional from qtpy.QtCore import QModelIndex, QSize, Qt, QThread, QTimer, Signal, Slot from qtpy.QtGui import ( @@ -176,8 +176,8 @@ def __init__( run_model: BaseRunModel, event_queue: SimpleQueue[StatusEvents], notifier: ErtNotifier, - parent: Optional[QWidget] = None, - output_path: Optional[Path] = None, + parent: QWidget | None = None, + output_path: Path | None = None, ): QFrame.__init__(self, parent) self.output_path = output_path @@ -190,7 +190,7 @@ def __init__( self._run_model = run_model self._event_queue = event_queue self._notifier = notifier - self.fail_msg_box: Optional[ErtMessageBox] = None + self.fail_msg_box: ErtMessageBox | None = None self._ticker = QTimer(self) self._ticker.timeout.connect(self._on_ticker) diff --git a/src/ert/gui/simulation/view/realization.py b/src/ert/gui/simulation/view/realization.py index d09d540be60..6859829e362 100644 --- a/src/ert/gui/simulation/view/realization.py +++ b/src/ert/gui/simulation/view/realization.py @@ -1,5 +1,3 @@ -from typing import Optional - from qtpy.QtCore import ( QAbstractItemModel, QEvent, @@ -34,10 +32,10 @@ class RealizationWidget(QWidget): - def __init__(self, _it: int, parent: Optional[QWidget] = None) -> None: + def __init__(self, it: int, parent: QWidget | None = None) -> None: super().__init__(parent) - self._iter = _it + self._iter = it self._delegate_size = QSize(90, 90) self._real_view = QListView(self) @@ -97,7 +95,7 @@ def __init__(self, size: QSize, parent: QObject) -> None: def paint( self, - painter: Optional[QPainter], + painter: QPainter | None, option: QStyleOptionViewItem, index: QModelIndex, ) -> None: @@ -149,7 +147,7 @@ def paint( def sizeHint(self, option: QStyleOptionViewItem, index: QModelIndex) -> QSize: return self._size - def eventFilter(self, object: Optional[QObject], event: Optional[QEvent]) -> bool: + def eventFilter(self, object: QObject | None, event: QEvent | None) -> bool: if event.type() == QEvent.Type.ToolTip: # type: ignore mouse_pos = event.pos() + self.adjustment_point_for_job_rect_margin # type: ignore parent: RealizationWidget = self.parent() # type: ignore diff --git a/src/ert/gui/simulation/view/update.py b/src/ert/gui/simulation/view/update.py index d295acf6f16..7f3716fa4b4 100644 --- a/src/ert/gui/simulation/view/update.py +++ b/src/ert/gui/simulation/view/update.py @@ -3,7 +3,6 @@ import math import time from datetime import timedelta -from typing import Optional import humanize from qtpy.QtCore import Qt, Slot @@ -39,7 +38,7 @@ class UpdateLogTable(QTableWidget): - def __init__(self, data: DataSection, parent: Optional[QWidget] = None): + def __init__(self, data: DataSection, parent: QWidget | None = None): super().__init__(parent) self.setColumnCount(len(data.header)) @@ -55,9 +54,9 @@ def __init__(self, data: DataSection, parent: Optional[QWidget] = None): for j, val in enumerate(row): self.setItem(i, j, QTableWidgetItem(str(val))) - def keyPressEvent(self, e: Optional[QKeyEvent]) -> None: + def keyPressEvent(self, e: QKeyEvent | None) -> None: if e is not None and e.matches(QKeySequence.Copy): - stream = str() + stream = "" for i in self.selectedIndexes(): item = self.itemFromIndex(i) assert item is not None @@ -78,7 +77,7 @@ def keyPressEvent(self, e: Optional[QKeyEvent]) -> None: class UpdateWidget(QWidget): - def __init__(self, iteration: int, parent: Optional[QWidget] = None) -> None: + def __init__(self, iteration: int, parent: QWidget | None = None) -> None: super().__init__(parent) self._iteration = iteration diff --git a/src/ert/gui/suggestor/suggestor.py b/src/ert/gui/suggestor/suggestor.py index 27754241f5e..99221baf2ab 100644 --- a/src/ert/gui/suggestor/suggestor.py +++ b/src/ert/gui/suggestor/suggestor.py @@ -4,7 +4,8 @@ import logging import webbrowser from collections import defaultdict -from typing import TYPE_CHECKING, Callable, Sequence +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING from qtpy.QtCore import Qt from qtpy.QtGui import QCursor diff --git a/src/ert/gui/tools/design_matrix/design_matrix_panel.py b/src/ert/gui/tools/design_matrix/design_matrix_panel.py index 2644c8d2ddc..e92e81063c6 100644 --- a/src/ert/gui/tools/design_matrix/design_matrix_panel.py +++ b/src/ert/gui/tools/design_matrix/design_matrix_panel.py @@ -1,5 +1,3 @@ -from typing import Optional - import pandas as pd from qtpy.QtGui import QStandardItem, QStandardItemModel from qtpy.QtWidgets import QDialog, QTableView, QVBoxLayout, QWidget @@ -10,7 +8,7 @@ def __init__( self, design_matrix_df: pd.DataFrame, filename: str, - parent: Optional[QWidget] = None, + parent: QWidget | None = None, ) -> None: super().__init__(parent) diff --git a/src/ert/gui/tools/event_viewer/panel.py b/src/ert/gui/tools/event_viewer/panel.py index da458cf7ba5..3dfd06fb654 100644 --- a/src/ert/gui/tools/event_viewer/panel.py +++ b/src/ert/gui/tools/event_viewer/panel.py @@ -1,6 +1,6 @@ import logging +from collections.abc import Iterator from contextlib import contextmanager -from typing import Iterator from qtpy import QtCore from qtpy.QtCore import QObject diff --git a/src/ert/gui/tools/event_viewer/tool.py b/src/ert/gui/tools/event_viewer/tool.py index 2b49ed7e288..c083b06a79d 100644 --- a/src/ert/gui/tools/event_viewer/tool.py +++ b/src/ert/gui/tools/event_viewer/tool.py @@ -1,5 +1,3 @@ -from typing import Optional - from qtpy.QtCore import QObject, Slot from qtpy.QtGui import QIcon @@ -9,9 +7,7 @@ class EventViewerTool(Tool, QObject): - def __init__( - self, gui_handler: GUILogHandler, config_filename: Optional[str] = None - ): + def __init__(self, gui_handler: GUILogHandler, config_filename: str | None = None): super().__init__( "Event viewer", QIcon("img:notifications.svg"), diff --git a/src/ert/gui/tools/export/export_panel.py b/src/ert/gui/tools/export/export_panel.py index 71638fa76a7..1086f5ae66e 100644 --- a/src/ert/gui/tools/export/export_panel.py +++ b/src/ert/gui/tools/export/export_panel.py @@ -1,16 +1,11 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from qtpy.QtWidgets import QCheckBox, QWidget -from ert.gui.ertwidgets import ( - CustomDialog, - ListEditBox, - PathChooser, - PathModel, -) +from ert.gui.ertwidgets import CustomDialog, ListEditBox, PathChooser, PathModel if TYPE_CHECKING: from ert.config import ErtConfig @@ -22,7 +17,7 @@ def __init__( self, ert_config: ErtConfig, storage: LocalStorage, - parent: Optional[QWidget] = None, + parent: QWidget | None = None, ) -> None: self.storage = storage description = "The CSV export requires some information before it starts:" @@ -60,7 +55,7 @@ def __init__( self.addButtons() @property - def output_path(self) -> Optional[str]: + def output_path(self) -> str | None: return self.output_path_model.getPath() @property @@ -77,7 +72,7 @@ def ensemble_data_as_json(self) -> str: return json.dumps(ensembles) @property - def design_matrix_path(self) -> Optional[str]: + def design_matrix_path(self) -> str | None: path = self.design_matrix_path_model.getPath() if not path or not path.strip(): path = None diff --git a/src/ert/gui/tools/file/file_dialog.py b/src/ert/gui/tools/file/file_dialog.py index 91b0b5757bf..8ab5d21d60b 100644 --- a/src/ert/gui/tools/file/file_dialog.py +++ b/src/ert/gui/tools/file/file_dialog.py @@ -1,5 +1,4 @@ from math import floor -from typing import Optional from qtpy.QtCore import QSize, Qt, QThread from qtpy.QtGui import QClipboard, QFontDatabase, QTextCursor, QTextOption @@ -46,7 +45,7 @@ def __init__( job_number: int, realization: int, iteration: int, - parent: Optional[QWidget] = None, + parent: QWidget | None = None, ) -> None: super().__init__(parent) @@ -58,7 +57,7 @@ def __init__( try: # We take care to close this file in _quit_thread() - self._file = open(file_name, "r", encoding="utf-8") # noqa: SIM115 + self._file = open(file_name, encoding="utf-8") # noqa: SIM115 except OSError as error: self._mb = QMessageBox( QMessageBox.Critical, diff --git a/src/ert/gui/tools/file/file_update_worker.py b/src/ert/gui/tools/file/file_update_worker.py index a6bb511fa97..25b9343853d 100644 --- a/src/ert/gui/tools/file/file_update_worker.py +++ b/src/ert/gui/tools/file/file_update_worker.py @@ -1,5 +1,4 @@ from io import TextIOWrapper -from typing import Optional from qtpy.QtCore import QObject, QTimer, Signal, Slot @@ -11,10 +10,10 @@ class FileUpdateWorker(QObject): read = Signal(str) - def __init__(self, file: TextIOWrapper, parent: Optional[QObject] = None) -> None: + def __init__(self, file: TextIOWrapper, parent: QObject | None = None) -> None: super().__init__(parent) self._file = file - self._timer: Optional[QTimer] = None + self._timer: QTimer | None = None @Slot() def stop(self) -> None: diff --git a/src/ert/gui/tools/load_results/load_results_tool.py b/src/ert/gui/tools/load_results/load_results_tool.py index da514a0163c..faf5a4b58bd 100644 --- a/src/ert/gui/tools/load_results/load_results_tool.py +++ b/src/ert/gui/tools/load_results/load_results_tool.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from qtpy.QtGui import QIcon @@ -16,8 +16,8 @@ def __init__(self, facade: LibresFacade, notifier: ErtNotifier) -> None: "Load results manually", QIcon("img:upload.svg"), ) - self._import_widget: Optional[LoadResultsPanel] = None - self._dialog: Optional[ClosableDialog] = None + self._import_widget: LoadResultsPanel | None = None + self._dialog: ClosableDialog | None = None self._notifier = notifier def trigger(self) -> None: diff --git a/src/ert/gui/tools/manage_experiments/storage_info_widget.py b/src/ert/gui/tools/manage_experiments/storage_info_widget.py index a8fb0390e72..64fb5b4b211 100644 --- a/src/ert/gui/tools/manage_experiments/storage_info_widget.py +++ b/src/ert/gui/tools/manage_experiments/storage_info_widget.py @@ -1,6 +1,5 @@ import json from enum import IntEnum -from typing import Optional import polars import seaborn as sns @@ -48,7 +47,7 @@ class _EnsembleWidgetTabs(IntEnum): class _ExperimentWidget(QWidget): def __init__(self) -> None: QWidget.__init__(self) - self._experiment: Optional[Experiment] = None + self._experiment: Experiment | None = None self._responses_text_edit = QTextEdit() self._responses_text_edit.setReadOnly(True) @@ -114,7 +113,7 @@ def setExperiment(self, experiment: Experiment) -> None: class _EnsembleWidget(QWidget): def __init__(self) -> None: QWidget.__init__(self) - self._ensemble: Optional[Ensemble] = None + self._ensemble: Ensemble | None = None info_frame = QFrame() self._name_label = QLabel() @@ -248,7 +247,7 @@ def _try_render_scaled_obs() -> None: [polars.col("scaling_factor").product()] ) joined_small = joined_small.with_columns( - (joined_small["std"] * joined_small["scaling_factor"]) + joined_small["std"] * joined_small["scaling_factor"] ) ax.errorbar( diff --git a/src/ert/gui/tools/manage_experiments/storage_model.py b/src/ert/gui/tools/manage_experiments/storage_model.py index 216775e6d7c..54ac392adfa 100644 --- a/src/ert/gui/tools/manage_experiments/storage_model.py +++ b/src/ert/gui/tools/manage_experiments/storage_model.py @@ -3,13 +3,7 @@ from uuid import UUID import humanize -from qtpy.QtCore import ( - QAbstractItemModel, - QModelIndex, - QObject, - Qt, - Slot, -) +from qtpy.QtCore import QAbstractItemModel, QModelIndex, QObject, Qt, Slot from qtpy.QtWidgets import QApplication from typing_extensions import override diff --git a/src/ert/gui/tools/manage_experiments/storage_widget.py b/src/ert/gui/tools/manage_experiments/storage_widget.py index 3d419ecc92b..a062254607a 100644 --- a/src/ert/gui/tools/manage_experiments/storage_widget.py +++ b/src/ert/gui/tools/manage_experiments/storage_widget.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable from qtpy.QtCore import ( QAbstractItemModel, diff --git a/src/ert/gui/tools/plot/customize/color_chooser.py b/src/ert/gui/tools/plot/customize/color_chooser.py index 5e195794706..de8f7495915 100644 --- a/src/ert/gui/tools/plot/customize/color_chooser.py +++ b/src/ert/gui/tools/plot/customize/color_chooser.py @@ -1,5 +1,3 @@ -from typing import Optional, Tuple - from qtpy.QtCore import QRect, QSize, Signal, Slot from qtpy.QtGui import QColor, QMouseEvent, QPainter, QPaintEvent from qtpy.QtWidgets import QColorDialog, QFrame @@ -43,13 +41,13 @@ def color(self) -> QColor: return self._color @color.setter - def color(self, color: Tuple[str, float]) -> None: + def color(self, color: tuple[str, float]) -> None: new_color = QColor(color[0]) new_color.setAlphaF(color[1]) self._color = new_color self.update() - def paintEvent(self, event: Optional[QPaintEvent]) -> None: + def paintEvent(self, event: QPaintEvent | None) -> None: """Paints the box""" painter = QPainter(self) rect = self.contentsRect() @@ -69,7 +67,7 @@ def paintEvent(self, event: Optional[QPaintEvent]) -> None: QFrame.paintEvent(self, event) - def mouseReleaseEvent(self, event: Optional[QMouseEvent]) -> None: + def mouseReleaseEvent(self, event: QMouseEvent | None) -> None: if event: self.mouseRelease.emit() return super().mouseReleaseEvent(event) diff --git a/src/ert/gui/tools/plot/customize/customization_view.py b/src/ert/gui/tools/plot/customize/customization_view.py index 04877c2ad41..24dd5530e00 100644 --- a/src/ert/gui/tools/plot/customize/customization_view.py +++ b/src/ert/gui/tools/plot/customize/customization_view.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Optional +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from qtpy.QtWidgets import ( QCheckBox, @@ -27,14 +28,14 @@ def __init__(self) -> None: self.setLayout(self._layout) self._widgets: dict[str, QWidget] = {} - def addRow(self, title: Optional[str], widget: Optional[QWidget]) -> None: + def addRow(self, title: str | None, widget: QWidget | None) -> None: self._layout.addRow(title, widget) def addLineEdit( self, attribute_name: str, title: str, - tool_tip: Optional[str] = None, + tool_tip: str | None = None, placeholder: str = "", ) -> None: self[attribute_name] = ClearableLineEdit(placeholder=placeholder) @@ -43,13 +44,13 @@ def addLineEdit( if tool_tip is not None: self[attribute_name].setToolTip(tool_tip) - def getter(self: Any) -> Optional[str]: - value: Optional[str] = str(self[attribute_name].text()) + def getter(self: Any) -> str | None: + value: str | None = str(self[attribute_name].text()) if not value: value = None return value - def setter(self: Any, value: Optional[str]) -> None: + def setter(self: Any, value: str | None) -> None: if value is None: value = "" self[attribute_name].setText(str(value)) @@ -57,7 +58,7 @@ def setter(self: Any, value: Optional[str]) -> None: self.updateProperty(attribute_name, getter, setter) def addCheckBox( - self, attribute_name: str, title: str, tool_tip: Optional[str] = None + self, attribute_name: str, title: str, tool_tip: str | None = None ) -> None: self[attribute_name] = QCheckBox() self.addRow(title, self[attribute_name]) @@ -77,7 +78,7 @@ def addSpinBox( self, attribute_name: str, title: str, - tool_tip: Optional[str] = None, + tool_tip: str | None = None, min_value: int = 1, max_value: int = 10, single_step: int = 1, @@ -110,7 +111,7 @@ def addStyleChooser( self, attribute_name: str, title: str, - tool_tip: Optional[str] = None, + tool_tip: str | None = None, line_style_set: str = STYLESET_DEFAULT, ) -> None: style_chooser = StyleChooser(line_style_set=line_style_set) diff --git a/src/ert/gui/tools/plot/customize/customize_plot_dialog.py b/src/ert/gui/tools/plot/customize/customize_plot_dialog.py index 40ce107e04a..ca4bbf520a4 100644 --- a/src/ert/gui/tools/plot/customize/customize_plot_dialog.py +++ b/src/ert/gui/tools/plot/customize/customize_plot_dialog.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Iterable, Iterator +from collections.abc import Iterable, Iterator +from typing import TYPE_CHECKING from qtpy.QtCore import QObject, Qt, Signal from qtpy.QtGui import QIcon, QKeyEvent @@ -304,7 +305,7 @@ def addTab(self, attribute_name: str, title: str, widget: QWidget) -> None: self._tab_map[attribute_name] = widget self._tab_order.append(attribute_name) - def __getitem__(self, item: str) -> "CustomizationView": + def __getitem__(self, item: str) -> CustomizationView: return self._tab_map[item] def __iter__(self) -> Iterator[QWidget]: diff --git a/src/ert/gui/tools/plot/customize/limits_customization_view.py b/src/ert/gui/tools/plot/customize/limits_customization_view.py index dd8fb620cfb..fec3e58cccf 100644 --- a/src/ert/gui/tools/plot/customize/limits_customization_view.py +++ b/src/ert/gui/tools/plot/customize/limits_customization_view.py @@ -2,7 +2,7 @@ from copy import copy from datetime import date -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from typing import TYPE_CHECKING, Any, ClassVar from qtpy.QtGui import QDoubleValidator, QIntValidator from qtpy.QtWidgets import QLabel, QStackedWidget, QWidget @@ -19,17 +19,17 @@ class StackedInput(QStackedWidget): def __init__(self) -> None: QStackedWidget.__init__(self) - self._inputs: dict[Optional[str], QWidget] = {} - self._index_map: dict[Optional[str], int] = {} + self._inputs: dict[str | None, QWidget] = {} + self._index_map: dict[str | None, int] = {} self.addInput(PlotContext.UNKNOWN_AXIS, QLabel("Fixed")) - self._current_name: Optional[str] = PlotContext.UNKNOWN_AXIS + self._current_name: str | None = PlotContext.UNKNOWN_AXIS - def addInput(self, name: Optional[str], widget: QWidget) -> None: + def addInput(self, name: str | None, widget: QWidget) -> None: index = self.addWidget(widget) self._inputs[name] = widget self._index_map[name] = index - def switchToInput(self, name: Optional[str]) -> None: + def switchToInput(self, name: str | None) -> None: index_for_name = self._index_map[name] self.setCurrentIndex(index_for_name) self._current_name = name @@ -65,8 +65,8 @@ def __init__(self) -> None: @staticmethod def createDoubleLineEdit( - minimum: Optional[float] = None, - maximum: Optional[float] = None, + minimum: float | None = None, + maximum: float | None = None, placeholder: str = "", ) -> ClearableLineEdit: line_edit = ClearableLineEdit(placeholder=placeholder) @@ -83,8 +83,8 @@ def createDoubleLineEdit( @staticmethod def createIntegerLineEdit( - minimum: Optional[int] = None, - maximum: Optional[int] = None, + minimum: int | None = None, + maximum: int | None = None, placeholder: str = "", ) -> ClearableLineEdit: line_edit = ClearableLineEdit(placeholder=placeholder) @@ -99,32 +99,32 @@ def createIntegerLineEdit( line_edit.setValidator(validator) return line_edit - def setValue(self, axis_name: Optional[str], value: Any) -> None: - _input = self._inputs[axis_name] + def setValue(self, axis_name: str | None, value: Any) -> None: + input_ = self._inputs[axis_name] if axis_name in LimitsStack.NUMBER_AXIS: if value is None: - _input.setText("") + input_.setText("") else: - _input.setText(str(value)) + input_.setText(str(value)) elif axis_name == PlotContext.DATE_AXIS: - _input.setDate(value) + input_.setDate(value) - def getValue(self, axis_name: Optional[str]) -> Optional[float | int | date]: - _input = self._inputs[axis_name] + def getValue(self, axis_name: str | None) -> float | int | date | None: + input_ = self._inputs[axis_name] result = None if axis_name in LimitsStack.FLOAT_AXIS: try: - result = float(_input.text()) + result = float(input_.text()) except ValueError: result = None elif axis_name in LimitsStack.INT_AXIS: try: - result = int(_input.text()) + result = int(input_.text()) except ValueError: result = None elif axis_name == PlotContext.DATE_AXIS: - result = _input.date() + result = input_.date() return result @@ -134,11 +134,11 @@ def __init__(self) -> None: self._limits = PlotLimits() self._x_minimum_stack = LimitsStack() self._x_maximum_stack = LimitsStack() - self._x_current_input_name: Optional[str] = PlotContext.UNKNOWN_AXIS + self._x_current_input_name: str | None = PlotContext.UNKNOWN_AXIS self._y_minimum_stack = LimitsStack() self._y_maximum_stack = LimitsStack() - self._y_current_input_name: Optional[str] = PlotContext.UNKNOWN_AXIS + self._y_current_input_name: str | None = PlotContext.UNKNOWN_AXIS @property def x_minimum_stack(self) -> LimitsStack: @@ -204,9 +204,7 @@ def _updateLimits(self) -> None: maximum = self._y_maximum_stack.getValue(self._y_current_input_name) self._updateLimit(self._y_current_input_name, minimum, maximum) - def _updateLimit( - self, axis_name: Optional[str], minimum: Any, maximum: Any - ) -> None: + def _updateLimit(self, axis_name: str | None, minimum: Any, maximum: Any) -> None: if axis_name == PlotContext.COUNT_AXIS: self._limits.count_limits = minimum, maximum elif axis_name == PlotContext.DENSITY_AXIS: @@ -218,13 +216,13 @@ def _updateLimit( elif axis_name == PlotContext.VALUE_AXIS: self._limits.value_limits = minimum, maximum - def switchInputOnX(self, axis_type: Optional[str]) -> None: + def switchInputOnX(self, axis_type: str | None) -> None: self._x_current_input_name = axis_type self._updateWidgets() self._x_minimum_stack.switchToInput(axis_type) self._x_maximum_stack.switchToInput(axis_type) - def switchInputOnY(self, axis_type: Optional[str]) -> None: + def switchInputOnY(self, axis_type: str | None) -> None: self._y_current_input_name = axis_type self._updateWidgets() self._y_minimum_stack.switchToInput(axis_type) @@ -246,9 +244,7 @@ def __init__(self) -> None: self.addRow("Minimum", limits_widget.y_minimum_stack) self.addRow("Maximum", limits_widget.y_maximum_stack) - def setAxisTypes( - self, x_axis_type: Optional[str], y_axis_type: Optional[str] - ) -> None: + def setAxisTypes(self, x_axis_type: str | None, y_axis_type: str | None) -> None: self._limits_widget.switchInputOnX(x_axis_type) self._limits_widget.switchInputOnY(y_axis_type) diff --git a/src/ert/gui/tools/plot/customize/statistics_customization_view.py b/src/ert/gui/tools/plot/customize/statistics_customization_view.py index 5ad1a988737..9fbd4c991db 100644 --- a/src/ert/gui/tools/plot/customize/statistics_customization_view.py +++ b/src/ert/gui/tools/plot/customize/statistics_customization_view.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from qtpy.QtWidgets import QComboBox, QHBoxLayout @@ -123,8 +123,8 @@ def presetSelected(self, index: int) -> None: def updateStyle( self, attribute_name: str, - line_style: Optional[str], - marker_style: Optional[str], + line_style: str | None, + marker_style: str | None, ) -> None: style = getattr(self, attribute_name) style.line_style = line_style diff --git a/src/ert/gui/tools/plot/customize/style_chooser.py b/src/ert/gui/tools/plot/customize/style_chooser.py index 4e47640fcf0..46bd7a0fa9b 100644 --- a/src/ert/gui/tools/plot/customize/style_chooser.py +++ b/src/ert/gui/tools/plot/customize/style_chooser.py @@ -1,4 +1,4 @@ -from typing import Iterator +from collections.abc import Iterator from qtpy.QtWidgets import ( QComboBox, diff --git a/src/ert/gui/tools/plot/data_type_keys_widget.py b/src/ert/gui/tools/plot/data_type_keys_widget.py index 755ca472968..31e9dc26671 100644 --- a/src/ert/gui/tools/plot/data_type_keys_widget.py +++ b/src/ert/gui/tools/plot/data_type_keys_widget.py @@ -126,8 +126,8 @@ def getSelectedItem(self) -> PlotApiKeyDefinition | None: def selectDefault(self) -> None: self.data_type_keys_widget.setCurrentIndex(self.filter_model.index(0, 0)) - def setSearchString(self, _filter: str | None) -> None: - self.filter_model.setFilterFixedString(_filter) + def setSearchString(self, filter_: str | None) -> None: + self.filter_model.setFilterFixedString(filter_) def showFilterPopup(self) -> None: self.__filter_popup.show() diff --git a/src/ert/gui/tools/plot/data_type_proxy_model.py b/src/ert/gui/tools/plot/data_type_proxy_model.py index 30e34112e2b..09a18b77d76 100644 --- a/src/ert/gui/tools/plot/data_type_proxy_model.py +++ b/src/ert/gui/tools/plot/data_type_proxy_model.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from qtpy.QtCore import QModelIndex, QObject, QSortFilterProxyModel, Qt @@ -9,7 +9,7 @@ class DataTypeProxyModel(QSortFilterProxyModel): - def __init__(self, parent: Optional[QObject], model: DataTypeKeysListModel) -> None: + def __init__(self, parent: QObject | None, model: DataTypeKeysListModel) -> None: QSortFilterProxyModel.__init__(self, parent) self.__show_summary_keys = True diff --git a/src/ert/gui/tools/plot/plot_api.py b/src/ert/gui/tools/plot/plot_api.py index ac96f838f0c..b76ca97051b 100644 --- a/src/ert/gui/tools/plot/plot_api.py +++ b/src/ert/gui/tools/plot/plot_api.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from itertools import combinations as combi from json.decoder import JSONDecodeError -from typing import Any, Dict, List, NamedTuple, Optional +from typing import Any, NamedTuple from urllib.parse import quote import httpx @@ -25,35 +25,31 @@ class EnsembleObject: experiment_name: str -PlotApiKeyDefinition = NamedTuple( - "PlotApiKeyDefinition", - [ - ("key", str), - ("index_type", Optional[str]), - ("observations", bool), - ("dimensionality", int), - ("metadata", Dict[Any, Any]), - ("log_scale", bool), - ], -) +class PlotApiKeyDefinition(NamedTuple): + key: str + index_type: str | None + observations: bool + dimensionality: int + metadata: dict[Any, Any] + log_scale: bool class PlotApi: def __init__(self) -> None: - self._all_ensembles: Optional[List[EnsembleObject]] = None + self._all_ensembles: list[EnsembleObject] | None = None self._timeout = 120 @staticmethod def escape(s: str) -> str: return quote(quote(s, safe="")) - def _get_ensemble_by_id(self, id: str) -> Optional[EnsembleObject]: + def _get_ensemble_by_id(self, id: str) -> EnsembleObject | None: for ensemble in self.get_all_ensembles(): if ensemble.id == id: return ensemble return None - def get_all_ensembles(self) -> List[EnsembleObject]: + def get_all_ensembles(self) -> list[EnsembleObject]: if self._all_ensembles is not None: return self._all_ensembles @@ -97,7 +93,7 @@ def _check_response(response: httpx._models.Response) -> None: f"{response.text} from url: {response.url}." ) - def all_data_type_keys(self) -> List[PlotApiKeyDefinition]: + def all_data_type_keys(self) -> list[PlotApiKeyDefinition]: """Returns a list of all the keys except observation keys. The keys are a unique set of all keys in the ensembles @@ -105,7 +101,7 @@ def all_data_type_keys(self) -> List[PlotApiKeyDefinition]: For each key a dict is returned with info about the key""" - all_keys: Dict[str, PlotApiKeyDefinition] = {} + all_keys: dict[str, PlotApiKeyDefinition] = {} with StorageService.session() as client: response = client.get("/experiments", timeout=self._timeout) @@ -189,7 +185,7 @@ def data_for_key(self, ensemble_id: str, key: str) -> pd.DataFrame: except ValueError: return df - def observations_for_key(self, ensemble_ids: List[str], key: str) -> pd.DataFrame: + def observations_for_key(self, ensemble_ids: list[str], key: str) -> pd.DataFrame: """Returns a pandas DataFrame with the datapoints for a given observation key for a given ensembles. The row index is the realization number, and the column index is a multi-index with (obs_key, index/date, obs_index), where index/date is @@ -242,7 +238,7 @@ def observations_for_key(self, ensemble_ids: List[str], key: str) -> pd.DataFram return all_observations.T - def history_data(self, key: str, ensemble_ids: Optional[List[str]]) -> pd.DataFrame: + def history_data(self, key: str, ensemble_ids: list[str] | None) -> pd.DataFrame: """Returns a pandas DataFrame with the data points for the history for a given data key, if any. The row index is the index/date and the column index is the key.""" diff --git a/src/ert/gui/tools/plot/plot_ensemble_selection_widget.py b/src/ert/gui/tools/plot/plot_ensemble_selection_widget.py index 9da4b6e7bd7..e72147d93ff 100644 --- a/src/ert/gui/tools/plot/plot_ensemble_selection_widget.py +++ b/src/ert/gui/tools/plot/plot_ensemble_selection_widget.py @@ -1,4 +1,5 @@ -from typing import Any, Iterator, List, Optional +from collections.abc import Iterator +from typing import Any from qtpy.QtCore import QModelIndex, QSize, Qt, Signal from qtpy.QtGui import ( @@ -27,7 +28,7 @@ class EnsembleSelectionWidget(QWidget): ensembleSelectionChanged = Signal() - def __init__(self, ensembles: List[EnsembleObject]): + def __init__(self, ensembles: list[EnsembleObject]): QWidget.__init__(self) self.__dndlist = EnsembleSelectListWidget(ensembles[::-1]) @@ -40,7 +41,7 @@ def __init__(self, ensembles: List[EnsembleObject]): self.ensembleSelectionChanged.emit ) - def get_selected_ensembles(self) -> List[EnsembleObject]: + def get_selected_ensembles(self) -> list[EnsembleObject]: return self.__dndlist.get_checked_ensembles() @@ -49,7 +50,7 @@ class EnsembleSelectListWidget(QListWidget): MAXIMUM_SELECTED = 5 MINIMUM_SELECTED = 1 - def __init__(self, ensembles: List[EnsembleObject]): + def __init__(self, ensembles: list[EnsembleObject]): super().__init__() self._ensemble_count = 0 self.setObjectName("ensemble_selector") @@ -70,7 +71,7 @@ def __init__(self, ensembles: List[EnsembleObject]): self.setItemDelegate(CustomItemDelegate()) self.itemClicked.connect(self.slot_toggle_plot) - def get_checked_ensembles(self) -> List[EnsembleObject]: + def get_checked_ensembles(self) -> list[EnsembleObject]: def _iter() -> Iterator[EnsembleObject]: for index in range(self._ensemble_count): item = self.item(index) @@ -80,14 +81,14 @@ def _iter() -> Iterator[EnsembleObject]: return list(_iter()) - def mouseMoveEvent(self, e: Optional[QMouseEvent]) -> None: + def mouseMoveEvent(self, e: QMouseEvent | None) -> None: super().mouseMoveEvent(e) if e is not None and self.itemAt(e.pos()): self.setCursor(QCursor(Qt.CursorShape.PointingHandCursor)) else: self.setCursor(QCursor(Qt.CursorShape.ArrowCursor)) - def dropEvent(self, event: Optional[QDropEvent]) -> None: + def dropEvent(self, event: QDropEvent | None) -> None: super().dropEvent(event) self.ensembleSelectionListChanged.emit() @@ -113,7 +114,7 @@ def sizeHint(self, option: Any, index: Any) -> QSize: def paint( self, - painter: Optional[QPainter], + painter: QPainter | None, option: QStyleOptionViewItem, index: QModelIndex, ) -> None: diff --git a/src/ert/gui/tools/plot/plot_widget.py b/src/ert/gui/tools/plot/plot_widget.py index 92093a59ae1..a94e5f60633 100644 --- a/src/ert/gui/tools/plot/plot_widget.py +++ b/src/ert/gui/tools/plot/plot_widget.py @@ -1,6 +1,6 @@ import sys import traceback -from typing import TYPE_CHECKING, Dict, Optional, Union +from typing import TYPE_CHECKING, Union import numpy as np import numpy.typing as npt @@ -12,13 +12,7 @@ from matplotlib.figure import Figure from qtpy.QtCore import QStringListModel, Qt, Signal, Slot from qtpy.QtGui import QIcon -from qtpy.QtWidgets import ( - QAction, - QComboBox, - QVBoxLayout, - QWidget, - QWidgetAction, -) +from qtpy.QtWidgets import QAction, QComboBox, QVBoxLayout, QWidget, QWidgetAction from .plot_api import EnsembleObject @@ -40,7 +34,7 @@ class CustomNavigationToolbar(NavigationToolbar2QT): def __init__( self, canvas: FigureCanvas, - parent: Optional[QWidget], + parent: QWidget | None, coordinates: bool = True, ) -> None: super().__init__(canvas, parent, coordinates) # type: ignore @@ -104,7 +98,7 @@ def __init__( "CrossEnsembleStatisticsPlot", "StdDevPlot", ], - parent: Optional[QWidget] = None, + parent: QWidget | None = None, ) -> None: QWidget.__init__(self, parent) @@ -143,9 +137,9 @@ def name(self) -> str: def updatePlot( self, plot_context: "PlotContext", - ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], + ensemble_to_data_map: dict[EnsembleObject, pd.DataFrame], observations: pd.DataFrame, - std_dev_images: Dict[str, npt.NDArray[np.float32]], + std_dev_images: dict[str, npt.NDArray[np.float32]], ) -> None: self.resetPlot() try: diff --git a/src/ert/gui/tools/plot/plot_window.py b/src/ert/gui/tools/plot/plot_window.py index bcf77faa067..b3b8532e4c6 100644 --- a/src/ert/gui/tools/plot/plot_window.py +++ b/src/ert/gui/tools/plot/plot_window.py @@ -1,6 +1,6 @@ import logging import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -92,7 +92,7 @@ def open_error_dialog(title: str, content: str) -> None: class PlotWindow(QMainWindow): - def __init__(self, config_file: str, parent: Optional[QWidget]): + def __init__(self, config_file: str, parent: QWidget | None): QMainWindow.__init__(self, parent) t = time.perf_counter() @@ -127,7 +127,7 @@ def __init__(self, config_file: str, parent: Optional[QWidget]): self.setCentralWidget(central_widget) - self._plot_widgets: List[PlotWidget] = [] + self._plot_widgets: list[PlotWidget] = [] self.addPlotWidget(ENSEMBLE, EnsemblePlot()) self.addPlotWidget(STATISTICS, StatisticsPlot()) @@ -137,7 +137,7 @@ def __init__(self, config_file: str, parent: Optional[QWidget]): self.addPlotWidget(CROSS_ENSEMBLE_STATISTICS, CrossEnsembleStatisticsPlot()) self.addPlotWidget(STD_DEV, StdDevPlot()) self._central_tab.currentChanged.connect(self.currentTabChanged) - self._prev_tab_widget: Optional[QWidget] = None + self._prev_tab_widget: QWidget | None = None QApplication.setOverrideCursor(Qt.CursorShape.WaitCursor) try: @@ -169,10 +169,10 @@ def currentTabChanged(self, index: Any) -> None: self.updatePlot() @Slot(int) - def layerIndexChanged(self, index: Optional[int]) -> None: + def layerIndexChanged(self, index: int | None) -> None: self.updatePlot(index) - def updatePlot(self, layer: Optional[int] = None) -> None: + def updatePlot(self, layer: int | None = None) -> None: key_def = self.getSelectedKey() if key_def is None: return @@ -185,7 +185,7 @@ def updatePlot(self, layer: Optional[int] = None) -> None: selected_ensembles = ( self._ensemble_selection_widget.get_selected_ensembles() ) - ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame] = {} + ensemble_to_data_map: dict[EnsembleObject, pd.DataFrame] = {} for ensemble in selected_ensembles: try: ensemble_to_data_map[ensemble] = self._api.data_for_key( @@ -205,7 +205,7 @@ def updatePlot(self, layer: Optional[int] = None) -> None: logger.exception(f"plot api request failed: {e}") open_error_dialog("Request failed", f"{e}") - std_dev_images: Dict[str, npt.NDArray[np.float32]] = {} + std_dev_images: dict[str, npt.NDArray[np.float32]] = {} if "FIELD" in key_def.metadata["data_origin"]: plot_widget.showLayerWidget.emit(True) @@ -281,21 +281,19 @@ def _updateCustomizer( self._plot_customizer.setAxisTypes(x_axis_type, y_axis_type) - def getSelectedKey(self) -> Optional[PlotApiKeyDefinition]: + def getSelectedKey(self) -> PlotApiKeyDefinition | None: return self._data_type_keys_widget.getSelectedItem() def addPlotWidget( self, name: str, - plotter: Union[ - EnsemblePlot, - StatisticsPlot, - HistogramPlot, - GaussianKDEPlot, - DistributionPlot, - CrossEnsembleStatisticsPlot, - StdDevPlot, - ], + plotter: EnsemblePlot + | StatisticsPlot + | HistogramPlot + | GaussianKDEPlot + | DistributionPlot + | CrossEnsembleStatisticsPlot + | StdDevPlot, enabled: bool = True, ) -> None: plot_widget = PlotWidget(name, plotter) diff --git a/src/ert/gui/tools/plot/plottery/plot_config.py b/src/ert/gui/tools/plot/plottery/plot_config.py index 52e1ad2d19b..7a3eb971b3e 100644 --- a/src/ert/gui/tools/plot/plottery/plot_config.py +++ b/src/ert/gui/tools/plot/plottery/plot_config.py @@ -2,7 +2,7 @@ import itertools from copy import copy -from typing import Any, List, Optional, Tuple +from typing import Any from .plot_limits import PlotLimits from .plot_style import PlotStyle @@ -14,10 +14,10 @@ class PlotConfig: def __init__( self, - plot_settings: Optional[dict[str, Any]] = None, - title: Optional[str] = "Unnamed", - x_label: Optional[str] = None, - y_label: Optional[str] = None, + plot_settings: dict[str, Any] | None = None, + title: str | None = "Unnamed", + x_label: str | None = None, + y_label: str | None = None, ): self._title = title self._plot_settings = plot_settings @@ -72,7 +72,7 @@ def __init__( name="Distribution lines", line_style="-", alpha=0.25, width=1.0 ) self._distribution_line_style.setEnabled(False) - self._current_color: Optional[Tuple[str, float]] = None + self._current_color: tuple[str, float] | None = None self._legend_enabled = True self._grid_enabled = True @@ -88,22 +88,22 @@ def __init__( self._std_dev_factor = 1 # sigma 1 is default std dev - def currentColor(self) -> Tuple[str, float]: + def currentColor(self) -> tuple[str, float]: if self._current_color is None: return self.nextColor() return self._current_color - def nextColor(self) -> Tuple[str, float]: + def nextColor(self) -> tuple[str, float]: color = next(self._line_color_cycle) self._current_color = color return self._current_color - def setLineColorCycle(self, color_list: List[Tuple[str, float]]) -> None: + def setLineColorCycle(self, color_list: list[tuple[str, float]]) -> None: self._line_color_cycle_colors = color_list self._line_color_cycle = itertools.cycle(color_list) - def lineColorCycle(self) -> List[Tuple[str, float]]: + def lineColorCycle(self) -> list[tuple[str, float]]: return self._line_color_cycle_colors def addLegendItem(self, label: str, item: Any) -> None: @@ -113,7 +113,7 @@ def addLegendItem(self, label: str, item: Any) -> None: def title(self) -> str: return self._title if self._title is not None else "Unnamed" - def setTitle(self, title: Optional[str]) -> None: + def setTitle(self, title: str | None) -> None: self._title = title def isUnnamed(self) -> bool: @@ -125,7 +125,7 @@ def defaultStyle(self) -> PlotStyle: style.color, style.alpha = self.currentColor() return style - def observationsColor(self) -> Tuple[str, float]: + def observationsColor(self) -> tuple[str, float]: assert self._observs_style.color return (self._observs_style.color, self._observs_style.alpha) @@ -156,16 +156,16 @@ def distributionLineStyle(self) -> PlotStyle: style.copyStyleFrom(self._distribution_line_style) return style - def xLabel(self) -> Optional[str]: + def xLabel(self) -> str | None: return self._x_label - def yLabel(self) -> Optional[str]: + def yLabel(self) -> str | None: return self._y_label - def legendItems(self) -> List[Any]: + def legendItems(self) -> list[Any]: return self._legend_items - def legendLabels(self) -> List[str]: + def legendLabels(self) -> list[str]: return self._legend_labels def setXLabel(self, label: str) -> None: @@ -230,7 +230,7 @@ def setHistoryStyle(self, style: PlotStyle) -> None: self._history_style.width = style.width self._history_style.size = style.size - def setObservationsColor(self, color_tuple: Tuple[str, float]) -> None: + def setObservationsColor(self, color_tuple: tuple[str, float]) -> None: self._observs_style.color, self._observs_style.alpha = color_tuple def setObservationsStyle(self, style: PlotStyle) -> None: @@ -253,7 +253,7 @@ def limits(self) -> PlotLimits: def limits(self, value: PlotLimits) -> None: self._limits = copy(value) - def copyConfigFrom(self, other: "PlotConfig") -> None: + def copyConfigFrom(self, other: PlotConfig) -> None: self._default_style.copyStyleFrom(other._default_style, copy_enabled_state=True) self._history_style.copyStyleFrom(other._history_style, copy_enabled_state=True) self._histogram_style.copyStyleFrom( diff --git a/src/ert/gui/tools/plot/plottery/plot_config_history.py b/src/ert/gui/tools/plot/plottery/plot_config_history.py index e8346ad62b4..379b1431dc5 100644 --- a/src/ert/gui/tools/plot/plottery/plot_config_history.py +++ b/src/ert/gui/tools/plot/plottery/plot_config_history.py @@ -1,5 +1,3 @@ -from typing import List - from .plot_config import PlotConfig @@ -11,8 +9,8 @@ def __init__(self, name: str, initial: PlotConfig) -> None: super().__init__() self._name = name self._initial = PlotConfig.createCopy(initial) - self._undo_history: List[PlotConfig] = [] - self._redo_history: List[PlotConfig] = [] + self._undo_history: list[PlotConfig] = [] + self._redo_history: list[PlotConfig] = [] self._current = PlotConfig.createCopy(self._initial) def isUndoPossible(self) -> bool: diff --git a/src/ert/gui/tools/plot/plottery/plot_context.py b/src/ert/gui/tools/plot/plottery/plot_context.py index 273eaeb4f69..ab9bb5cae63 100644 --- a/src/ert/gui/tools/plot/plottery/plot_context.py +++ b/src/ert/gui/tools/plot/plottery/plot_context.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, ClassVar, List, Optional +from typing import TYPE_CHECKING, ClassVar if TYPE_CHECKING: from pandas import DataFrame @@ -29,26 +29,26 @@ class PlotContext: def __init__( self, plot_config: PlotConfig, - ensembles: List[EnsembleObject], + ensembles: list[EnsembleObject], key: str, - layer: Optional[int] = None, + layer: int | None = None, ) -> None: super().__init__() self._key = key self._ensembles = ensembles self._plot_config = plot_config - self.history_data: Optional[DataFrame] = None + self.history_data: DataFrame | None = None self._log_scale = False - self._layer: Optional[int] = layer + self._layer: int | None = layer self._date_support_active = True - self._x_axis: Optional[str] = None - self._y_axis: Optional[str] = None + self._x_axis: str | None = None + self._y_axis: str | None = None - def plotConfig(self) -> "PlotConfig": + def plotConfig(self) -> PlotConfig: return self._plot_config - def ensembles(self) -> List[EnsembleObject]: + def ensembles(self) -> list[EnsembleObject]: return self._ensembles def key(self) -> str: @@ -61,11 +61,11 @@ def isDateSupportActive(self) -> bool: return self._date_support_active @property - def layer(self) -> Optional[int]: + def layer(self) -> int | None: return self._layer @property - def x_axis(self) -> Optional[str]: + def x_axis(self) -> str | None: return self._x_axis @x_axis.setter @@ -77,7 +77,7 @@ def x_axis(self, value: str) -> None: self._x_axis = value @property - def y_axis(self) -> Optional[str]: + def y_axis(self) -> str | None: return self._y_axis @y_axis.setter diff --git a/src/ert/gui/tools/plot/plottery/plot_limits.py b/src/ert/gui/tools/plot/plottery/plot_limits.py index 062a5bc78a9..33c0891cb23 100644 --- a/src/ert/gui/tools/plot/plottery/plot_limits.py +++ b/src/ert/gui/tools/plot/plottery/plot_limits.py @@ -1,94 +1,94 @@ from dataclasses import dataclass from datetime import date -from typing import Optional, Tuple, Union +from typing import TypeAlias -Num = Union[float, int] +Num: TypeAlias = float | int @dataclass class PlotLimits: - value_limits: Tuple[Optional[Num], Optional[Num]] = (None, None) - index_limits: Tuple[Optional[int], Optional[int]] = (None, None) - count_limits: Tuple[Optional[int], Optional[int]] = (None, None) - density_limits: Tuple[Optional[Num], Optional[Num]] = (None, None) - date_limits: Tuple[Optional[date], Optional[date]] = (None, None) + value_limits: tuple[Num | None, Num | None] = (None, None) + index_limits: tuple[int | None, int | None] = (None, None) + count_limits: tuple[int | None, int | None] = (None, None) + density_limits: tuple[Num | None, Num | None] = (None, None) + date_limits: tuple[date | None, date | None] = (None, None) @property - def value_minimum(self) -> Optional[Num]: + def value_minimum(self) -> Num | None: return self.value_limits[0] @value_minimum.setter - def value_minimum(self, value: Optional[Num]) -> None: + def value_minimum(self, value: Num | None) -> None: self.value_limits = (value, self.value_limits[1]) @property - def value_maximum(self) -> Optional[Num]: + def value_maximum(self) -> Num | None: return self.value_limits[1] @value_maximum.setter - def value_maximum(self, value: Optional[Num]) -> None: + def value_maximum(self, value: Num | None) -> None: self.value_limits = (self.value_limits[0], value) @property - def count_minimum(self) -> Optional[int]: + def count_minimum(self) -> int | None: return self.count_limits[0] @count_minimum.setter - def count_minimum(self, value: Optional[int]) -> None: + def count_minimum(self, value: int | None) -> None: self.count_limits = (value, self.count_limits[1]) @property - def count_maximum(self) -> Optional[int]: + def count_maximum(self) -> int | None: return self.count_limits[1] @count_maximum.setter - def count_maximum(self, value: Optional[int]) -> None: + def count_maximum(self, value: int | None) -> None: self.count_limits = (self.count_limits[0], value) @property - def index_minimum(self) -> Optional[int]: + def index_minimum(self) -> int | None: return self.index_limits[0] @index_minimum.setter - def index_minimum(self, value: Optional[int]) -> None: + def index_minimum(self, value: int | None) -> None: self.index_limits = (value, self.index_limits[1]) @property - def index_maximum(self) -> Optional[int]: + def index_maximum(self) -> int | None: return self.index_limits[1] @index_maximum.setter - def index_maximum(self, value: Optional[int]) -> None: + def index_maximum(self, value: int | None) -> None: self.index_limits = (self.index_limits[0], value) @property - def density_minimum(self) -> Optional[Num]: + def density_minimum(self) -> Num | None: return self.density_limits[0] @density_minimum.setter - def density_minimum(self, value: Optional[Num]) -> None: + def density_minimum(self, value: Num | None) -> None: self.density_limits = (value, self.density_limits[1]) @property - def density_maximum(self) -> Optional[Num]: + def density_maximum(self) -> Num | None: return self.density_limits[1] @density_maximum.setter - def density_maximum(self, value: Optional[Num]) -> None: + def density_maximum(self, value: Num | None) -> None: self.density_limits = (self.density_limits[0], value) @property - def date_minimum(self) -> Optional[date]: + def date_minimum(self) -> date | None: return self.date_limits[0] @date_minimum.setter - def date_minimum(self, value: Optional[date]) -> None: + def date_minimum(self, value: date | None) -> None: self.date_limits = (value, self.date_limits[1]) @property - def date_maximum(self) -> Optional[date]: + def date_maximum(self) -> date | None: return self.date_limits[1] @date_maximum.setter - def date_maximum(self, value: Optional[date]) -> None: + def date_maximum(self, value: date | None) -> None: self.date_limits = (self.date_limits[0], value) diff --git a/src/ert/gui/tools/plot/plottery/plot_style.py b/src/ert/gui/tools/plot/plottery/plot_style.py index 009e510ae22..a3ae6ecb2e3 100644 --- a/src/ert/gui/tools/plot/plottery/plot_style.py +++ b/src/ert/gui/tools/plot/plottery/plot_style.py @@ -1,11 +1,8 @@ -from typing import Optional - - class PlotStyle: def __init__( self, name: str, - color: Optional[str] = "#000000", + color: str | None = "#000000", alpha: float = 1.0, line_style: str = "-", marker: str = "", diff --git a/src/ert/gui/tools/plot/plottery/plots/cesp.py b/src/ert/gui/tools/plot/plottery/plots/cesp.py index 9c735e3b063..1c12db2c073 100644 --- a/src/ert/gui/tools/plot/plottery/plots/cesp.py +++ b/src/ert/gui/tools/plot/plottery/plots/cesp.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING import numpy as np import pandas as pd @@ -21,21 +21,17 @@ from ert.gui.tools.plot.plottery import PlotConfig, PlotContext -CcsData = TypedDict( - "CcsData", - { - "index": List[int], - "mean": Dict[int, float], - "min": Dict[int, float], - "max": Dict[int, float], - "p10": Dict[int, float], - "p33": Dict[int, float], - "p50": Dict[int, float], - "p67": Dict[int, float], - "p90": Dict[int, float], - "std": Dict[int, float], - }, -) +class CcsData(TypedDict): + index: list[int] + mean: dict[int, float] + min: dict[int, float] + max: dict[int, float] + p10: dict[int, float] + p33: dict[int, float] + p50: dict[int, float] + p67: dict[int, float] + p90: dict[int, float] + std: dict[int, float] class CrossEnsembleStatisticsPlot: @@ -46,9 +42,9 @@ def __init__(self) -> None: def plot( figure: Figure, plot_context: PlotContext, - ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], + ensemble_to_data_map: dict[EnsembleObject, pd.DataFrame], observation_data: pd.DataFrame, - std_dev_images: Dict[str, npt.NDArray[np.float32]], + std_dev_images: dict[str, npt.NDArray[np.float32]], ) -> None: plotCrossEnsembleStatistics( figure, plot_context, ensemble_to_data_map, observation_data @@ -57,8 +53,8 @@ def plot( def plotCrossEnsembleStatistics( figure: Figure, - plot_context: "PlotContext", - ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], + plot_context: PlotContext, + ensemble_to_data_map: dict[EnsembleObject, pd.DataFrame], _observation_data: DataFrame, ) -> None: config = plot_context.plotConfig() @@ -132,7 +128,7 @@ def plotCrossEnsembleStatistics( ) -def _addStatisticsLegends(plot_config: "PlotConfig") -> None: +def _addStatisticsLegends(plot_config: PlotConfig) -> None: _addStatisticsLegend(plot_config, "mean") _addStatisticsLegend(plot_config, "p50") _addStatisticsLegend(plot_config, "min-max", 0.2) @@ -142,7 +138,7 @@ def _addStatisticsLegends(plot_config: "PlotConfig") -> None: def _addStatisticsLegend( - plot_config: "PlotConfig", style_name: str, alpha_multiplier: float = 1.0 + plot_config: PlotConfig, style_name: str, alpha_multiplier: float = 1.0 ) -> None: style = plot_config.getStatisticsStyle(style_name) if style.isVisible(): @@ -178,7 +174,7 @@ def _assertNumeric(data: pd.DataFrame) -> pd.Series: def _plotCrossEnsembleStatistics( - axes: "Axes", plot_config: "PlotConfig", data: CcsData, index: int + axes: Axes, plot_config: PlotConfig, data: CcsData, index: int ) -> None: axes.set_xlabel(plot_config.xLabel()) # type: ignore axes.set_ylabel(plot_config.yLabel()) # type: ignore @@ -293,8 +289,8 @@ def _plotCrossEnsembleStatistics( def _plotConnectionLines( - axes: "Axes", - plot_config: "PlotConfig", + axes: Axes, + plot_config: PlotConfig, ccs: CcsData, ) -> None: line_style = plot_config.distributionLineStyle() diff --git a/src/ert/gui/tools/plot/plottery/plots/distribution.py b/src/ert/gui/tools/plot/plottery/plots/distribution.py index c106fdf02c9..42cd727f73a 100644 --- a/src/ert/gui/tools/plot/plottery/plots/distribution.py +++ b/src/ert/gui/tools/plot/plottery/plots/distribution.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING import numpy as np import pandas as pd @@ -25,9 +25,9 @@ def __init__(self) -> None: def plot( figure: Figure, plot_context: PlotContext, - ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], + ensemble_to_data_map: dict[EnsembleObject, pd.DataFrame], observation_data: pd.DataFrame, - std_dev_images: Dict[str, npt.NDArray[np.float32]], + std_dev_images: dict[str, npt.NDArray[np.float32]], ) -> None: plotDistribution(figure, plot_context, ensemble_to_data_map, observation_data) @@ -35,7 +35,7 @@ def plot( def plotDistribution( figure: Figure, plot_context: PlotContext, - ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], + ensemble_to_data_map: dict[EnsembleObject, pd.DataFrame], _observation_data: pd.DataFrame, ) -> None: config = plot_context.plotConfig() @@ -49,7 +49,7 @@ def plotDistribution( axes.set_yscale("log") ensemble_list = plot_context.ensembles() - ensemble_indexes: List[int] = [] + ensemble_indexes: list[int] = [] previous_data = None for ensemble_index, (ensemble, data) in enumerate(ensemble_to_data_map.items()): ensemble_indexes.append(ensemble_index) @@ -85,12 +85,12 @@ def plotDistribution( def _plotDistribution( - axes: "Axes", - plot_config: "PlotConfig", + axes: Axes, + plot_config: PlotConfig, data: pd.DataFrame, label: str, index: int, - previous_data: Optional[pd.DataFrame], + previous_data: pd.DataFrame | None, ) -> None: data = pd.Series(dtype="float64") if data.empty else data[0] diff --git a/src/ert/gui/tools/plot/plottery/plots/ensemble.py b/src/ert/gui/tools/plot/plottery/plots/ensemble.py index 5b0a7b1a07f..400908062fd 100644 --- a/src/ert/gui/tools/plot/plottery/plots/ensemble.py +++ b/src/ert/gui/tools/plot/plottery/plots/ensemble.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING import numpy as np import pandas as pd @@ -28,9 +28,9 @@ def plot( self, figure: Figure, plot_context: PlotContext, - ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], + ensemble_to_data_map: dict[EnsembleObject, pd.DataFrame], observation_data: pd.DataFrame, - std_dev_images: Dict[str, npt.NDArray[np.float32]], + std_dev_images: dict[str, npt.NDArray[np.float32]], ) -> None: config = plot_context.plotConfig() axes = figure.add_subplot(111) @@ -74,7 +74,7 @@ def _plotLines( plot_config: PlotConfig, data: pd.DataFrame, ensemble_label: str, - draw_style: Optional[str] = None, + draw_style: str | None = None, ) -> None: style = plot_config.defaultStyle() diff --git a/src/ert/gui/tools/plot/plottery/plots/gaussian_kde.py b/src/ert/gui/tools/plot/plottery/plots/gaussian_kde.py index 230bfef51b2..40736e24d3e 100644 --- a/src/ert/gui/tools/plot/plottery/plots/gaussian_kde.py +++ b/src/ert/gui/tools/plot/plottery/plots/gaussian_kde.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -26,9 +26,9 @@ def __init__(self) -> None: def plot( figure: Figure, plot_context: PlotContext, - ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], + ensemble_to_data_map: dict[EnsembleObject, pd.DataFrame], observation_data: pd.DataFrame, - std_dev_images: Dict[str, npt.NDArray[np.float32]], + std_dev_images: dict[str, npt.NDArray[np.float32]], ) -> None: plotGaussianKDE(figure, plot_context, ensemble_to_data_map, observation_data) @@ -36,7 +36,7 @@ def plot( def plotGaussianKDE( figure: Figure, plot_context: PlotContext, - ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], + ensemble_to_data_map: dict[EnsembleObject, pd.DataFrame], _observation_data: Any, ) -> None: config = plot_context.plotConfig() diff --git a/src/ert/gui/tools/plot/plottery/plots/histogram.py b/src/ert/gui/tools/plot/plottery/plots/histogram.py index bdd33812665..5361b8dfb16 100644 --- a/src/ert/gui/tools/plot/plottery/plots/histogram.py +++ b/src/ert/gui/tools/plot/plottery/plots/histogram.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Sequence from math import ceil, floor, log10, sqrt -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -27,9 +28,9 @@ def __init__(self) -> None: def plot( figure: Figure, plot_context: PlotContext, - ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], + ensemble_to_data_map: dict[EnsembleObject, pd.DataFrame], observation_data: pd.DataFrame, - std_dev_images: Dict[str, npt.NDArray[np.float32]], + std_dev_images: dict[str, npt.NDArray[np.float32]], ) -> None: plotHistogram(figure, plot_context, ensemble_to_data_map, observation_data) @@ -37,7 +38,7 @@ def plot( def plotHistogram( figure: Figure, plot_context: PlotContext, - ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], + ensemble_to_data_map: dict[EnsembleObject, pd.DataFrame], observation_data: pd.DataFrame, ) -> None: config = plot_context.plotConfig() @@ -106,7 +107,7 @@ def plotHistogram( maximum = current_max if maximum is None else max(maximum, current_max) # type: ignore max_element_count = max(max_element_count, len(data[ensemble.name].index)) - bin_count = int(ceil(sqrt(max_element_count))) + bin_count = ceil(sqrt(max_element_count)) axes = {} for index, ensemble in enumerate(ensemble_list): @@ -169,10 +170,10 @@ def plotHistogram( def _plotCategoricalHistogram( - axes: "Axes", + axes: Axes, style: PlotStyle, data: pd.DataFrame, - categories: List[str], + categories: list[str], ) -> Rectangle: counts = data.value_counts() freq = [counts.get(category, 0) for category in categories] @@ -189,15 +190,15 @@ def _plotCategoricalHistogram( def _plotHistogram( - axes: "Axes", + axes: Axes, style: PlotStyle, data: pd.DataFrame, bin_count: int, use_log_scale: float = False, - minimum: Optional[float] = None, - maximum: Optional[float] = None, + minimum: float | None = None, + maximum: float | None = None, ) -> Rectangle: - bins: Union[Sequence[float], int] + bins: Sequence[float] | int if minimum is not None and maximum is not None: if use_log_scale: bins = _histogramLogBins(bin_count, minimum, maximum) # type: ignore @@ -225,8 +226,8 @@ def _histogramLogBins( minimum = log10(float(minimum)) maximum = log10(float(maximum)) - min_value = int(floor(minimum)) - max_value = int(ceil(maximum)) + min_value = floor(minimum) + max_value = ceil(maximum) log_bin_count = max_value - min_value diff --git a/src/ert/gui/tools/plot/plottery/plots/plot_tools.py b/src/ert/gui/tools/plot/plottery/plots/plot_tools.py index 78cacda5afb..76ec2bd059d 100644 --- a/src/ert/gui/tools/plot/plottery/plots/plot_tools.py +++ b/src/ert/gui/tools/plot/plottery/plots/plot_tools.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING if TYPE_CHECKING: from datetime import date @@ -27,11 +27,12 @@ def showLegend(axes: Axes, plot_context: PlotContext) -> None: @staticmethod def _getXAxisLimits( plot_context: PlotContext, - ) -> Optional[ - tuple[Optional[int], Optional[int]] - | tuple[Optional[float], Optional[float]] - | tuple[Optional[date], Optional[date]] - ]: + ) -> ( + tuple[int | None, int | None] + | tuple[float | None, float | None] + | tuple[date | None, date | None] + | None + ): limits = plot_context.plotConfig().limits axis_name = plot_context.x_axis @@ -51,11 +52,12 @@ def _getXAxisLimits( @staticmethod def _getYAxisLimits( plot_context: PlotContext, - ) -> Optional[ - tuple[Optional[int], Optional[int]] - | tuple[Optional[float], Optional[float]] - | tuple[Optional[date], Optional[date]] - ]: + ) -> ( + tuple[int | None, int | None] + | tuple[float | None, float | None] + | tuple[date | None, date | None] + | None + ): limits = plot_context.plotConfig().limits axis_name = plot_context.y_axis diff --git a/src/ert/gui/tools/plot/plottery/plots/statistics.py b/src/ert/gui/tools/plot/plottery/plots/statistics.py index 450b0aeaafd..67690cf0c92 100644 --- a/src/ert/gui/tools/plot/plottery/plots/statistics.py +++ b/src/ert/gui/tools/plot/plottery/plots/statistics.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING import numpy as np from matplotlib.lines import Line2D @@ -29,9 +29,9 @@ def __init__(self) -> None: def plot( figure: Figure, plot_context: PlotContext, - ensemble_to_data_map: Dict[EnsembleObject, DataFrame], - _observation_data: DataFrame, - std_dev_images: Dict[str, npt.NDArray[np.float32]], + ensemble_to_data_map: dict[EnsembleObject, DataFrame], + observation_data: DataFrame, + std_dev_images: dict[str, npt.NDArray[np.float32]], ) -> None: config = plot_context.plotConfig() axes = figure.add_subplot(111) @@ -73,7 +73,7 @@ def plot( _addStatisticsLegends(plot_config=config) - plotObservations(_observation_data, plot_context, axes) + plotObservations(observation_data, plot_context, axes) plotHistory(plot_context, axes) default_x_label = "Date" if plot_context.isDateSupportActive() else "Index" @@ -120,7 +120,7 @@ def _addStatisticsLegend( def _plotPercentiles( - axes: "Axes", plot_config: PlotConfig, data: DataFrame, ensemble_label: str + axes: Axes, plot_config: PlotConfig, data: DataFrame, ensemble_label: str ) -> None: style = plot_config.getStatisticsStyle("mean") if style.isVisible(): diff --git a/src/ert/gui/tools/plot/plottery/plots/std_dev.py b/src/ert/gui/tools/plot/plottery/plots/std_dev.py index 1c39f1e2233..03aa2a77c5e 100644 --- a/src/ert/gui/tools/plot/plottery/plots/std_dev.py +++ b/src/ert/gui/tools/plot/plottery/plots/std_dev.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any import matplotlib.pyplot as plt import numpy as np @@ -22,9 +22,9 @@ def plot( self, figure: Figure, plot_context: PlotContext, - ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], + ensemble_to_data_map: dict[EnsembleObject, pd.DataFrame], observation_data: pd.DataFrame, - std_dev_data: Dict[str, npt.NDArray[np.float32]], + std_dev_data: dict[str, npt.NDArray[np.float32]], ) -> None: ensemble_count = len(plot_context.ensembles()) layer = plot_context.layer diff --git a/src/ert/gui/tools/plot/widgets/clearable_line_edit.py b/src/ert/gui/tools/plot/widgets/clearable_line_edit.py index f24917eb287..63cfb1120be 100644 --- a/src/ert/gui/tools/plot/widgets/clearable_line_edit.py +++ b/src/ert/gui/tools/plot/widgets/clearable_line_edit.py @@ -1,5 +1,3 @@ -from typing import Optional - from qtpy.QtCore import QSize, Qt from qtpy.QtGui import QColor, QFocusEvent, QIcon, QKeyEvent, QResizeEvent from qtpy.QtWidgets import QLineEdit, QPushButton, QStyle @@ -42,7 +40,7 @@ def minimumSizeHint(self) -> QSize: size = QLineEdit.minimumSizeHint(self) return QSize(size.width() + self._clear_button.width() + 3, size.height()) - def resizeEvent(self, a0: Optional[QResizeEvent]) -> None: + def resizeEvent(self, a0: QResizeEvent | None) -> None: right = self.rect().right() style = self.style() assert style is not None @@ -72,16 +70,16 @@ def hidePlaceHolder(self) -> None: palette.setColor(self.foregroundRole(), self._active_color) self.setPalette(palette) - def focusInEvent(self, a0: Optional[QFocusEvent]) -> None: + def focusInEvent(self, a0: QFocusEvent | None) -> None: QLineEdit.focusInEvent(self, a0) self.hidePlaceHolder() - def focusOutEvent(self, a0: Optional[QFocusEvent]) -> None: + def focusOutEvent(self, a0: QFocusEvent | None) -> None: QLineEdit.focusOutEvent(self, a0) if not QLineEdit.text(self): self.showPlaceholder() - def keyPressEvent(self, a0: Optional[QKeyEvent]) -> None: + def keyPressEvent(self, a0: QKeyEvent | None) -> None: if a0 is not None and a0.key() == Qt.Key.Key_Escape: self.clear() self.clearFocus() @@ -89,7 +87,7 @@ def keyPressEvent(self, a0: Optional[QKeyEvent]) -> None: QLineEdit.keyPressEvent(self, a0) - def setText(self, a0: Optional[str]) -> None: + def setText(self, a0: str | None) -> None: self.hidePlaceHolder() QLineEdit.setText(self, a0) diff --git a/src/ert/gui/tools/plot/widgets/copy_style_to_dialog.py b/src/ert/gui/tools/plot/widgets/copy_style_to_dialog.py index 90feac8cea8..059a5500566 100644 --- a/src/ert/gui/tools/plot/widgets/copy_style_to_dialog.py +++ b/src/ert/gui/tools/plot/widgets/copy_style_to_dialog.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any from qtpy.QtGui import QIcon from qtpy.QtWidgets import ( @@ -24,9 +24,9 @@ class CopyStyleToDialog(QDialog): def __init__( self, - parent: Optional[QWidget], + parent: QWidget | None, current_key: Any, - key_defs: List[PlotApiKeyDefinition], + key_defs: list[PlotApiKeyDefinition], ): QWidget.__init__(self, parent) self.setMinimumWidth(450) @@ -66,10 +66,10 @@ def __init__( layout.addRow(button_layout) - def getSelectedKeys(self) -> List[str]: + def getSelectedKeys(self) -> list[str]: return self._list_model.getSelectedItems() - def filterSettingsChanged(self, item: Dict[str, bool]) -> None: + def filterSettingsChanged(self, item: dict[str, bool]) -> None: for value, visible in item.items(): self._list_model.setFilterOnMetadata("data_origin", value, visible) self._cl.modelChanged() diff --git a/src/ert/gui/tools/plot/widgets/custom_date_edit.py b/src/ert/gui/tools/plot/widgets/custom_date_edit.py index 5cf0ba19f86..27779469f46 100644 --- a/src/ert/gui/tools/plot/widgets/custom_date_edit.py +++ b/src/ert/gui/tools/plot/widgets/custom_date_edit.py @@ -1,5 +1,4 @@ import datetime -from typing import Optional, Union from qtpy.QtCore import QDate from qtpy.QtGui import QIcon @@ -44,7 +43,7 @@ def __init__(self) -> None: self._calendar_widget.activated.connect(self.setDate) - def setDate(self, date: Union[datetime.date, QDate]) -> None: + def setDate(self, date: datetime.date | QDate) -> None: if isinstance(date, datetime.date): date = QDate(date.year, date.month, date.day) # type: ignore @@ -53,7 +52,7 @@ def setDate(self, date: Union[datetime.date, QDate]) -> None: else: self._line_edit.setText("") - def date(self) -> Optional[datetime.date]: + def date(self) -> datetime.date | None: date_string = self._line_edit.text() if len(str(date_string).strip()) > 0: date = QDate.fromString(date_string, "yyyy-MM-dd") diff --git a/src/ert/gui/tools/plot/widgets/filter_popup.py b/src/ert/gui/tools/plot/widgets/filter_popup.py index 5596ced5c73..fb4a0a91fce 100644 --- a/src/ert/gui/tools/plot/widgets/filter_popup.py +++ b/src/ert/gui/tools/plot/widgets/filter_popup.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING from qtpy.QtCore import QEvent, Qt, Signal from qtpy.QtGui import QCursor @@ -22,7 +22,7 @@ class FilterPopup(QDialog): filterSettingsChanged = Signal(dict) def __init__( - self, parent: Optional[QWidget], key_defs: List[PlotApiKeyDefinition] + self, parent: QWidget | None, key_defs: list[PlotApiKeyDefinition] ) -> None: QDialog.__init__( self, @@ -54,22 +54,22 @@ def __init__( self.setLayout(layout) self.adjustSize() - def addFilterItem(self, name: str, _id: str, value: bool = True) -> None: - self.filter_items[_id] = value + def addFilterItem(self, name: str, id_: str, value: bool = True) -> None: + self.filter_items[id_] = value check_box = QCheckBox(name) check_box.setChecked(value) check_box.setObjectName("FilterCheckBox") def toggleItem(checked: bool) -> None: - self.filter_items[_id] = checked + self.filter_items[id_] = checked self.filterSettingsChanged.emit(self.filter_items) check_box.toggled.connect(toggleItem) self.__layout.addWidget(check_box) - def leaveEvent(self, event: Optional[QEvent]) -> None: + def leaveEvent(self, event: QEvent | None) -> None: self.hide() QWidget.leaveEvent(self, event) diff --git a/src/ert/gui/tools/plot/widgets/filterable_kw_list_model.py b/src/ert/gui/tools/plot/widgets/filterable_kw_list_model.py index 2952df7cdf3..d573ed86935 100644 --- a/src/ert/gui/tools/plot/widgets/filterable_kw_list_model.py +++ b/src/ert/gui/tools/plot/widgets/filterable_kw_list_model.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING from ert.gui.ertwidgets import SelectableListModel @@ -14,12 +14,12 @@ class FilterableKwListModel(SelectableListModel): SelectableListModel """ - def __init__(self, key_defs: List[PlotApiKeyDefinition]): + def __init__(self, key_defs: list[PlotApiKeyDefinition]): SelectableListModel.__init__(self, [k.key for k in key_defs]) self._key_defs = key_defs - self._metadata_filters: Dict[str, Dict[str, bool]] = {} + self._metadata_filters: dict[str, dict[str, bool]] = {} - def getList(self) -> List[str]: + def getList(self) -> list[str]: items = [] for item in self._key_defs: add = True diff --git a/src/ert/gui/tools/plugins/plugin_handler.py b/src/ert/gui/tools/plugins/plugin_handler.py index 4032d34bcba..fbced1ad385 100644 --- a/src/ert/gui/tools/plugins/plugin_handler.py +++ b/src/ert/gui/tools/plugins/plugin_handler.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Iterator +from collections.abc import Iterator +from typing import TYPE_CHECKING from .plugin import Plugin diff --git a/src/ert/gui/tools/plugins/plugin_runner.py b/src/ert/gui/tools/plugins/plugin_runner.py index 9ca21edcc3a..1fea02f6edc 100644 --- a/src/ert/gui/tools/plugins/plugin_runner.py +++ b/src/ert/gui/tools/plugins/plugin_runner.py @@ -1,7 +1,8 @@ from __future__ import annotations import time -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from _ert.threading import ErtThread from ert.config import CancelPluginException @@ -18,7 +19,7 @@ class PluginRunner: def __init__( - self, plugin: "Plugin", ert_config: ErtConfig, storage: LocalStorage + self, plugin: Plugin, ert_config: ErtConfig, storage: LocalStorage ) -> None: super().__init__() self.ert_config = ert_config @@ -29,7 +30,7 @@ def __init__( self.__result = None self._runner = WorkflowJobRunner(plugin.getWorkflowJob()) - self.poll_thread: Optional[ErtThread] = None + self.poll_thread: ErtThread | None = None def run(self) -> None: try: @@ -70,7 +71,7 @@ def run(self) -> None: print("Plugin cancelled before execution!") def __runWorkflowJob( - self, arguments: Optional[List[Any]], fixtures: Dict[str, Any] + self, arguments: list[Any] | None, fixtures: dict[str, Any] ) -> None: self.__result = self._runner.run(arguments, fixtures=fixtures) diff --git a/src/ert/gui/tools/plugins/plugins_tool.py b/src/ert/gui/tools/plugins/plugins_tool.py index 0fc351e77f8..519d120d959 100644 --- a/src/ert/gui/tools/plugins/plugins_tool.py +++ b/src/ert/gui/tools/plugins/plugins_tool.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from qtpy.QtGui import QIcon from qtpy.QtWidgets import QMenu @@ -52,7 +52,7 @@ def get_menu(self) -> QMenu: def trigger(self) -> None: self.notifier.emitErtChange() # plugin may have added new cases. - def get_plugin_runner(self, plugin_name: str) -> Optional[PluginRunner]: + def get_plugin_runner(self, plugin_name: str) -> PluginRunner | None: for pulgin, runner in self.__plugins.items(): if pulgin.getName() == plugin_name: return runner diff --git a/src/ert/gui/tools/plugins/process_job_dialog.py b/src/ert/gui/tools/plugins/process_job_dialog.py index 854de172132..38198f8cb89 100644 --- a/src/ert/gui/tools/plugins/process_job_dialog.py +++ b/src/ert/gui/tools/plugins/process_job_dialog.py @@ -1,4 +1,4 @@ -from typing import Optional, cast +from typing import cast from qtpy.QtCore import QSize, Qt, Signal from qtpy.QtGui import QCloseEvent, QKeyEvent, QMovie @@ -25,7 +25,7 @@ class ProcessJobDialog(QDialog): closeButtonPressed = Signal() cancelConfirmed = Signal() - def __init__(self, title: str, parent: Optional[QWidget] = None) -> None: + def __init__(self, title: str, parent: QWidget | None = None) -> None: QDialog.__init__(self, parent) self.__parent = parent @@ -80,7 +80,7 @@ def __init__(self, title: str, parent: Optional[QWidget] = None) -> None: self.presentError.connect(self.__presentError) self.closeButtonPressed.connect(self.__confirmCancel) - self._msg_box: Optional[QMessageBox] = None + self._msg_box: QMessageBox | None = None def disableCloseButton(self) -> None: self.close_button.setEnabled(False) @@ -88,7 +88,7 @@ def disableCloseButton(self) -> None: def enableCloseButton(self) -> None: self.close_button.setEnabled(True) - def keyPressEvent(self, a0: Optional[QKeyEvent]) -> None: + def keyPressEvent(self, a0: QKeyEvent | None) -> None: # disallow pressing escape to close # when close button is not enabled if ( @@ -98,15 +98,15 @@ def keyPressEvent(self, a0: Optional[QKeyEvent]) -> None: ): QDialog.keyPressEvent(self, a0) - def closeEvent(self, a0: Optional[QCloseEvent]) -> None: + def closeEvent(self, a0: QCloseEvent | None) -> None: if a0 is not None: a0.ignore() self.closeButtonPressed.emit() def __createMsgBox( - self, title: Optional[str], message: Optional[str], details: str + self, title: str | None, message: str | None, details: str ) -> QMessageBox: - msg_box = QMessageBox(cast(Optional[QWidget], self.parent())) + msg_box = QMessageBox(cast(QWidget | None, self.parent())) msg_box.setText(title) msg_box.setInformativeText(message) @@ -122,7 +122,7 @@ def __createMsgBox( return msg_box def __presentInformation( - self, title: Optional[str], message: Optional[str], details: str + self, title: str | None, message: str | None, details: str ) -> None: self._msg_box = self.__createMsgBox(title, message, details) self._msg_box.setIcon(QMessageBox.Information) @@ -130,7 +130,7 @@ def __presentInformation( self._msg_box.exec_() def __presentError( - self, title: Optional[str], message: Optional[str], details: str + self, title: str | None, message: str | None, details: str ) -> None: self._msg_box = self.__createMsgBox(title, message, details) self._msg_box.setIcon(QMessageBox.Critical) diff --git a/src/ert/gui/tools/search_bar/search_bar.py b/src/ert/gui/tools/search_bar/search_bar.py index 51b2d62e36a..59bb5ad6547 100644 --- a/src/ert/gui/tools/search_bar/search_bar.py +++ b/src/ert/gui/tools/search_bar/search_bar.py @@ -1,12 +1,6 @@ from qtpy import QtCore from qtpy.QtGui import QBrush, QColor, QTextCharFormat, QTextCursor -from qtpy.QtWidgets import ( - QBoxLayout, - QHBoxLayout, - QLabel, - QLineEdit, - QPlainTextEdit, -) +from qtpy.QtWidgets import QBoxLayout, QHBoxLayout, QLabel, QLineEdit, QPlainTextEdit class SearchBar(QLineEdit): @@ -47,10 +41,10 @@ def search_bar_changed(self, value: str) -> None: ) def get_layout(self) -> QBoxLayout: - _layout = QHBoxLayout() - _layout.addWidget(self._label) - _layout.addWidget(self) - return _layout + layout = QHBoxLayout() + layout.addWidget(self._label) + layout.addWidget(self) + return layout def select_text(self, start: int, length: int) -> None: self._cursor.setPosition(start) diff --git a/src/ert/gui/tools/tool.py b/src/ert/gui/tools/tool.py index f977d05e3ae..fc24e4c6f07 100644 --- a/src/ert/gui/tools/tool.py +++ b/src/ert/gui/tools/tool.py @@ -1,5 +1,3 @@ -from typing import Optional - from qtpy.QtCore import QObject from qtpy.QtGui import QIcon from qtpy.QtWidgets import QAction @@ -17,7 +15,7 @@ def __init__( super().__init__() self.__icon = icon self.__name = name - self.__parent: Optional[QObject] = None + self.__parent: QObject | None = None self.__enabled = enabled self.__checkable = checkable self.__is_popup_menu = popup_menu @@ -37,11 +35,11 @@ def getName(self) -> str: def trigger(self) -> None: raise NotImplementedError() - def setParent(self, parent: Optional[QObject]) -> None: + def setParent(self, parent: QObject | None) -> None: self.__parent = parent self.__action.setParent(parent) - def parent(self) -> Optional[QObject]: + def parent(self) -> QObject | None: return self.__parent def isEnabled(self) -> bool: diff --git a/src/ert/gui/tools/workflows/run_workflow_widget.py b/src/ert/gui/tools/workflows/run_workflow_widget.py index a42c0936624..a258b45afbb 100644 --- a/src/ert/gui/tools/workflows/run_workflow_widget.py +++ b/src/ert/gui/tools/workflows/run_workflow_widget.py @@ -1,7 +1,8 @@ from __future__ import annotations import time -from typing import TYPE_CHECKING, Iterable, Optional +from collections.abc import Iterable +from typing import TYPE_CHECKING from qtpy.QtCore import QSize, Qt, Signal from qtpy.QtGui import QIcon, QMovie @@ -57,13 +58,13 @@ def __init__(self, config: ErtConfig, notifier: ErtNotifier): self.setLayout(layout) - self._running_workflow_dialog: Optional[WorkflowDialog] = None + self._running_workflow_dialog: WorkflowDialog | None = None self.workflowSucceeded.connect(self.workflowFinished) self.workflowFailed.connect(self.workflowFinishedWithFail) self.workflowKilled.connect(self.workflowStoppedByUser) - self._workflow_runner: Optional[WorkflowRunner] = None + self._workflow_runner: WorkflowRunner | None = None def createSpinWidget(self) -> QWidget: widget = QWidget() diff --git a/src/ert/gui/tools/workflows/workflow_dialog.py b/src/ert/gui/tools/workflows/workflow_dialog.py index 13f4aab91ad..0927b922f4c 100644 --- a/src/ert/gui/tools/workflows/workflow_dialog.py +++ b/src/ert/gui/tools/workflows/workflow_dialog.py @@ -1,5 +1,3 @@ -from typing import Optional - from qtpy.QtCore import Qt, Signal from qtpy.QtGui import QKeyEvent from qtpy.QtWidgets import ( @@ -16,7 +14,7 @@ class WorkflowDialog(QDialog): closeButtonPressed = Signal() def __init__( - self, title: str, widget: QWidget, parent: Optional[QWidget] = None + self, title: str, widget: QWidget, parent: QWidget | None = None ) -> None: QDialog.__init__(self, parent) @@ -53,7 +51,7 @@ def disableCloseButton(self) -> None: def enableCloseButton(self) -> None: self.close_button.setEnabled(True) - def keyPressEvent(self, a0: Optional[QKeyEvent]) -> None: + def keyPressEvent(self, a0: QKeyEvent | None) -> None: # disallow pressing escape to close # when close button is not enabled if ( diff --git a/src/ert/libres_facade.py b/src/ert/libres_facade.py index 0d54dd4c11f..b25862149a1 100644 --- a/src/ert/libres_facade.py +++ b/src/ert/libres_facade.py @@ -3,13 +3,12 @@ import asyncio import logging import warnings +from collections.abc import Callable, Iterable from multiprocessing.pool import ThreadPool from pathlib import Path from typing import ( TYPE_CHECKING, Any, - Callable, - Iterable, ) import numpy as np @@ -148,7 +147,7 @@ def load_from_run_path( ensemble.refresh_ensemble_state() return loaded - def get_observations(self) -> "EnkfObs": + def get_observations(self) -> EnkfObs: return self.config.enkf_obs def get_data_key_for_obs_key(self, observation_key: str) -> str: @@ -230,7 +229,7 @@ def run_ertscript( # type: ignore @classmethod def from_config_file( cls, config_file: str, read_only: bool = False - ) -> "LibresFacade": + ) -> LibresFacade: with ErtPluginContext(): return cls( ErtConfig.with_plugins().from_file(config_file), diff --git a/src/ert/load_status.py b/src/ert/load_status.py index 2af23f78690..e31ccc3bf82 100644 --- a/src/ert/load_status.py +++ b/src/ert/load_status.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import NamedTuple, Optional +from typing import NamedTuple class LoadStatus(Enum): @@ -8,5 +8,5 @@ class LoadStatus(Enum): class LoadResult(NamedTuple): - status: Optional[LoadStatus] + status: LoadStatus | None message: str diff --git a/src/ert/logging/__init__.py b/src/ert/logging/__init__.py index 8c9f182cdc8..2075a0fc3f1 100644 --- a/src/ert/logging/__init__.py +++ b/src/ert/logging/__init__.py @@ -4,7 +4,7 @@ import sys from datetime import datetime from types import TracebackType -from typing import Any, Optional, Tuple, Type, Union +from typing import Any LOGGING_CONFIG = pathlib.Path(__file__).parent.resolve() / "logger.conf" STORAGE_LOG_CONFIG = pathlib.Path(__file__).parent.resolve() / "storage_log.conf" @@ -70,9 +70,7 @@ def formatMessage(record: logging.LogRecord) -> str: @staticmethod def formatException( - _: Union[ - Tuple[Type[BaseException], BaseException, Optional[TracebackType]], - Tuple[None, None, None], - ], + _: tuple[type[BaseException], BaseException, TracebackType | None] + | tuple[None, None, None], ) -> str: return "" diff --git a/src/ert/namespace.py b/src/ert/namespace.py index 82191277027..3c2f285caea 100644 --- a/src/ert/namespace.py +++ b/src/ert/namespace.py @@ -1,7 +1,7 @@ from __future__ import annotations import argparse -from typing import Callable, Optional +from collections.abc import Callable from ert.plugins.plugin_manager import ErtPluginManager @@ -16,5 +16,5 @@ class Namespace(argparse.Namespace): verbose: bool experimental_mode: bool logdir: str - experiment_name: Optional[str] = None - func: Callable[[Namespace, Optional[ErtPluginManager]], None] + experiment_name: str | None = None + func: Callable[[Namespace, ErtPluginManager | None], None] diff --git a/src/ert/plugins/__init__.py b/src/ert/plugins/__init__.py index 1993a57ffe9..2b95bce0b99 100644 --- a/src/ert/plugins/__init__.py +++ b/src/ert/plugins/__init__.py @@ -1,5 +1,6 @@ +from collections.abc import Callable from functools import wraps -from typing import Any, Callable, ParamSpec +from typing import Any, ParamSpec from .plugin_manager import ( ErtPluginContext, diff --git a/src/ert/plugins/hook_implementations/forward_model_steps.py b/src/ert/plugins/hook_implementations/forward_model_steps.py index 4ddfca2747d..e1a4442cdd7 100644 --- a/src/ert/plugins/hook_implementations/forward_model_steps.py +++ b/src/ert/plugins/hook_implementations/forward_model_steps.py @@ -3,7 +3,7 @@ import subprocess from pathlib import Path from textwrap import dedent -from typing import Literal, Type +from typing import Literal import yaml @@ -577,7 +577,7 @@ def documentation() -> ForwardModelStepDocumentation | None: ) -_UpperCaseFMSteps: list[Type[ForwardModelStepPlugin]] = [ +_UpperCaseFMSteps: list[type[ForwardModelStepPlugin]] = [ CarefulCopyFile, CopyDirectory, CopyFile, @@ -600,7 +600,7 @@ def documentation() -> ForwardModelStepDocumentation | None: # executables with no validation. def _create_lowercase_fm_step_cls_with_only_executable( fm_step_name: str, executable: str -) -> Type[ForwardModelStepPlugin]: +) -> type[ForwardModelStepPlugin]: class _LowerCaseFMStep(ForwardModelStepPlugin): def __init__(self) -> None: super().__init__(name=fm_step_name, command=[executable]) @@ -612,7 +612,7 @@ def documentation() -> ForwardModelStepDocumentation | None: return _LowerCaseFMStep -_LowerCaseFMSteps: list[Type[ForwardModelStepPlugin]] = [] +_LowerCaseFMSteps: list[type[ForwardModelStepPlugin]] = [] for fm_step_subclass in _UpperCaseFMSteps: assert issubclass(fm_step_subclass, ForwardModelStepPlugin) inst = fm_step_subclass() # type: ignore @@ -624,7 +624,7 @@ def documentation() -> ForwardModelStepDocumentation | None: @plugin(name="ert") -def installable_forward_model_steps() -> list[Type[ForwardModelStepPlugin]]: +def installable_forward_model_steps() -> list[type[ForwardModelStepPlugin]]: return [*_UpperCaseFMSteps, *_LowerCaseFMSteps] diff --git a/src/ert/plugins/hook_implementations/jobs.py b/src/ert/plugins/hook_implementations/jobs.py index 85f5dd4c5a8..68f14dcbed2 100644 --- a/src/ert/plugins/hook_implementations/jobs.py +++ b/src/ert/plugins/hook_implementations/jobs.py @@ -1,5 +1,4 @@ import os -from typing import Dict, List from jinja2 import Template @@ -7,7 +6,7 @@ from ert.shared import ert_share_path -def _get_jobs_from_directories(directories: List[str]) -> Dict[str, str]: +def _get_jobs_from_directories(directories: list[str]) -> dict[str, str]: share_path = ert_share_path() directories = [ Template(directory).render(ERT_SHARE_PATH=share_path, ERT_UI_MODE="gui") @@ -27,7 +26,7 @@ def _get_jobs_from_directories(directories: List[str]) -> Dict[str, str]: @ert.plugin(name="ert") -def installable_workflow_jobs() -> Dict[str, str]: +def installable_workflow_jobs() -> dict[str, str]: directories = [ "{{ERT_SHARE_PATH}}/workflows/jobs/shell", "{{ERT_SHARE_PATH}}/workflows/jobs/internal-{{ERT_UI_MODE}}/config", diff --git a/src/ert/plugins/hook_implementations/workflows/disable_parameters.py b/src/ert/plugins/hook_implementations/workflows/disable_parameters.py index c5ee7b4d1a4..b76f8fad0ee 100644 --- a/src/ert/plugins/hook_implementations/workflows/disable_parameters.py +++ b/src/ert/plugins/hook_implementations/workflows/disable_parameters.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any from ert.config.ert_script import ErtScript from ert.config.parsing.config_errors import ConfigValidationError @@ -13,7 +13,7 @@ def run(self, disable_parameters: str) -> None: raise NotImplementedError(DisableParametersUpdate.__doc__) @staticmethod - def validate(args: List[Any]) -> None: + def validate(args: list[Any]) -> None: raise ConfigValidationError( f"DISABLE_PARAMETERS is removed, use the UPDATE:FALSE " f"option to the parameter instead:" diff --git a/src/ert/plugins/hook_implementations/workflows/export_misfit_data.py b/src/ert/plugins/hook_implementations/workflows/export_misfit_data.py index a57058f1da3..f7e44887a0c 100644 --- a/src/ert/plugins/hook_implementations/workflows/export_misfit_data.py +++ b/src/ert/plugins/hook_implementations/workflows/export_misfit_data.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, List +from typing import TYPE_CHECKING, Any import pandas as pd @@ -24,7 +24,7 @@ class ExportMisfitDataJob(ErtScript): """ def run( - self, ert_config: ErtConfig, ensemble: Ensemble, workflow_args: List[Any] + self, ert_config: ErtConfig, ensemble: Ensemble, workflow_args: list[Any] ) -> None: target_file = "misfit.hdf" if not workflow_args else workflow_args[0] diff --git a/src/ert/plugins/hook_implementations/workflows/export_runpath.py b/src/ert/plugins/hook_implementations/workflows/export_runpath.py index ff843f39ee0..86c1ca6a31d 100644 --- a/src/ert/plugins/hook_implementations/workflows/export_runpath.py +++ b/src/ert/plugins/hook_implementations/workflows/export_runpath.py @@ -34,7 +34,7 @@ class ExportRunpathJob(ErtScript): """ def run(self, ert_config: ErtConfig, workflow_args: list[Any]) -> None: - _args = " ".join(workflow_args).split() # Make sure args is a list of words + args = " ".join(workflow_args).split() # Make sure args is a list of words run_paths = Runpaths( jobname_format=ert_config.model_config.jobname_format_string, runpath_format=ert_config.model_config.runpath_format_string, @@ -44,7 +44,7 @@ def run(self, ert_config: ErtConfig, workflow_args: list[Any]) -> None: ) run_paths.write_runpath_list( *self.get_ranges( - _args, + args, ert_config.analysis_config.num_iterations, ert_config.model_config.num_realizations, ) diff --git a/src/ert/plugins/hook_implementations/workflows/misfit_preprocessor.py b/src/ert/plugins/hook_implementations/workflows/misfit_preprocessor.py index 17fb031f3c9..cb0d5a20476 100644 --- a/src/ert/plugins/hook_implementations/workflows/misfit_preprocessor.py +++ b/src/ert/plugins/hook_implementations/workflows/misfit_preprocessor.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any from ert.config.ert_script import ErtScript from ert.config.parsing.config_errors import ConfigValidationError @@ -13,7 +13,7 @@ def run(self, disable_parameters: str) -> None: raise NotImplementedError(MisfitPreprocessor.__doc__) @staticmethod - def validate(args: List[Any]) -> None: + def validate(args: list[Any]) -> None: message = MisfitPreprocessor.__doc__ assert message is not None if args: diff --git a/src/ert/plugins/hook_specifications/__init__.py b/src/ert/plugins/hook_specifications/__init__.py index ef23b588eb6..890fa7acbc2 100644 --- a/src/ert/plugins/hook_specifications/__init__.py +++ b/src/ert/plugins/hook_specifications/__init__.py @@ -15,10 +15,7 @@ job_documentation, legacy_ertscript_workflow, ) -from .logging import ( - add_log_handle_to_root, - add_span_processor, -) +from .logging import add_log_handle_to_root, add_span_processor from .site_config import site_config_lines __all__ = [ diff --git a/src/ert/plugins/hook_specifications/forward_model_steps.py b/src/ert/plugins/hook_specifications/forward_model_steps.py index dcb04aa4f41..646add4fead 100644 --- a/src/ert/plugins/hook_specifications/forward_model_steps.py +++ b/src/ert/plugins/hook_specifications/forward_model_steps.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Type, no_type_check +from typing import TYPE_CHECKING, no_type_check from ert.plugins.plugin_manager import hook_specification @@ -12,7 +12,7 @@ @no_type_check @hook_specification def installable_forward_model_steps() -> ( - PluginResponse[list[Type[ForwardModelStepPlugin]]] + PluginResponse[list[type[ForwardModelStepPlugin]]] ): """ :return: List of forward model step plugins in the form of subclasses of the @@ -22,7 +22,7 @@ def installable_forward_model_steps() -> ( @no_type_check @hook_specification -def forward_model_configuration() -> PluginResponse[list[Type[ForwardModelStepPlugin]]]: +def forward_model_configuration() -> PluginResponse[list[type[ForwardModelStepPlugin]]]: """ :return: List of configurations to be merged to be provided to forward model steps. """ diff --git a/src/ert/plugins/hook_specifications/help_resources.py b/src/ert/plugins/hook_specifications/help_resources.py index e1e18995598..d6ebf3d5256 100644 --- a/src/ert/plugins/hook_specifications/help_resources.py +++ b/src/ert/plugins/hook_specifications/help_resources.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING from ert.plugins.plugin_manager import hook_specification @@ -9,7 +9,7 @@ @hook_specification -def help_links() -> PluginResponse[Dict[str, str]]: # type: ignore +def help_links() -> PluginResponse[dict[str, str]]: # type: ignore """Have a look at the ingredients and offer your own. :return: Dictionary with link as values and link labels as keys diff --git a/src/ert/plugins/hook_specifications/jobs.py b/src/ert/plugins/hook_specifications/jobs.py index d60e57f0944..cf28ce10419 100644 --- a/src/ert/plugins/hook_specifications/jobs.py +++ b/src/ert/plugins/hook_specifications/jobs.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Optional, no_type_check +from typing import TYPE_CHECKING, no_type_check from ert.plugins.plugin_manager import hook_specification @@ -11,7 +11,7 @@ @no_type_check @hook_specification -def installable_jobs() -> PluginResponse[Dict[str, str]]: +def installable_jobs() -> PluginResponse[dict[str, str]]: """ :return: dict with job names as keys and path to config as value :rtype: PluginResponse with data as dict[str,str] @@ -20,7 +20,7 @@ def installable_jobs() -> PluginResponse[Dict[str, str]]: @no_type_check @hook_specification(firstresult=True) -def job_documentation(job_name: str) -> PluginResponse[Optional[Dict[str, str]]]: +def job_documentation(job_name: str) -> PluginResponse[dict[str, str] | None]: """ :return: If job_name is from your plugin return dict with documentation fields as keys and corresponding @@ -38,7 +38,7 @@ def job_documentation(job_name: str) -> PluginResponse[Optional[Dict[str, str]]] @no_type_check @hook_specification -def installable_workflow_jobs() -> PluginResponse[Dict[str, str]]: +def installable_workflow_jobs() -> PluginResponse[dict[str, str]]: """ :return: dict with workflow job names as keys and path to config as value """ diff --git a/src/ert/plugins/hook_specifications/site_config.py b/src/ert/plugins/hook_specifications/site_config.py index 6ce28393fb2..069ff0ede11 100644 --- a/src/ert/plugins/hook_specifications/site_config.py +++ b/src/ert/plugins/hook_specifications/site_config.py @@ -1,10 +1,8 @@ -from typing import List - from ert.plugins.plugin_manager import hook_specification @hook_specification -def site_config_lines() -> List[str]: # type: ignore +def site_config_lines() -> list[str]: # type: ignore """ :return: List of lines to append to site config file """ diff --git a/src/ert/plugins/plugin_manager.py b/src/ert/plugins/plugin_manager.py index 2df1fa26bb6..0a4ce86dd19 100644 --- a/src/ert/plugins/plugin_manager.py +++ b/src/ert/plugins/plugin_manager.py @@ -6,19 +6,10 @@ import shutil import tempfile from argparse import ArgumentParser +from collections.abc import Callable, Mapping, Sequence from itertools import chain from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Literal, - Mapping, - Sequence, - Type, - TypeVar, - overload, -) +from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload import pluggy from opentelemetry.sdk.trace import TracerProvider @@ -86,7 +77,7 @@ def get_help_links(self) -> dict[str, Any]: @property def forward_model_steps( self, - ) -> list[Type[ForwardModelStepPlugin]]: + ) -> list[type[ForwardModelStepPlugin]]: fm_steps_listed = [ resp.data for resp in self.hook.installable_forward_model_steps() ] @@ -445,7 +436,7 @@ def _reset_environment(self) -> None: def __exit__( self, exception: BaseException, - exception_type: Type[BaseException], + exception_type: type[BaseException], traceback: TracebackType, ) -> None: self._reset_environment() diff --git a/src/ert/plugins/workflow_config.py b/src/ert/plugins/workflow_config.py index 84104b5d61b..af5c6704a24 100644 --- a/src/ert/plugins/workflow_config.py +++ b/src/ert/plugins/workflow_config.py @@ -5,7 +5,8 @@ import os import tempfile from argparse import ArgumentParser -from typing import Any, Callable, Type +from collections.abc import Callable +from typing import Any logger = logging.getLogger(__name__) @@ -20,7 +21,7 @@ def __init__(self) -> None: self._workflows: list[WorkflowConfig] = [] def add_workflow( - self, ert_script: Type[Any], name: str | None = None + self, ert_script: type[Any], name: str | None = None ) -> WorkflowConfig: """ @@ -52,7 +53,7 @@ class WorkflowConfig: """ def __init__( - self, ertscript_class: Type[Any], tmpdir: str, name: str | None = None + self, ertscript_class: type[Any], tmpdir: str, name: str | None = None ) -> None: """ :param ertscript_class: Class inheriting from ErtScript @@ -111,7 +112,7 @@ def category(self, category: str) -> None: self._category = category @staticmethod - def _get_func_name(func: Type[Any], name: str | None) -> str: + def _get_func_name(func: type[Any], name: str | None) -> str: return name if name else func.__name__ def _write_workflow_config(self, output_dir: str) -> str: @@ -122,6 +123,6 @@ def _write_workflow_config(self, output_dir: str) -> str: return file_path @staticmethod - def _get_source_package(module: Type[Any]) -> str: + def _get_source_package(module: type[Any]) -> str: base, _, _ = module.__module__.partition(".") return base diff --git a/src/ert/resources/forward_models/res/script/ecl_config.py b/src/ert/resources/forward_models/res/script/ecl_config.py index 2d37b44a2e0..b0bbeb9196c 100644 --- a/src/ert/resources/forward_models/res/script/ecl_config.py +++ b/src/ert/resources/forward_models/res/script/ecl_config.py @@ -4,7 +4,7 @@ import subprocess import sys from pathlib import Path -from typing import Any, Dict +from typing import Any import yaml @@ -45,7 +45,7 @@ def __init__( self, version: str, executable: str, - env: Dict[str, str], + env: dict[str, str], mpirun: str | None = None, ): self.version: str = version @@ -53,7 +53,7 @@ def __init__( raise OSError(f"The executable: '{executable}' can not be executed by user") self.executable: str = executable - self.env: Dict[str, str] = env + self.env: dict[str, str] = env self.mpirun: str | None = mpirun self.name: str = "simulator" @@ -121,11 +121,11 @@ def _get_version(self, version_arg: str | None) -> str: return version - def _get_env(self, version: str, exe_type: str) -> Dict[str, str]: - env: Dict[str, str] = {} + def _get_env(self, version: str, exe_type: str) -> dict[str, str]: + env: dict[str, str] = {} env.update(self._config.get(Keys.env, {})) - mpi_sim: Dict[str, Any] = self._config[Keys.versions][ + mpi_sim: dict[str, Any] = self._config[Keys.versions][ self._get_version(version) ][exe_type] env.update(mpi_sim.get(Keys.env, {})) @@ -134,7 +134,7 @@ def _get_env(self, version: str, exe_type: str) -> Dict[str, str]: def _get_sim(self, version: str | None, exe_type: str) -> Simulator: version = self._get_version(version) - binaries: Dict[str, str] = self._config[Keys.versions][version][exe_type] + binaries: dict[str, str] = self._config[Keys.versions][version][exe_type] mpirun = binaries[Keys.mpirun] if exe_type == Keys.mpi else None return Simulator( version, diff --git a/src/ert/resources/forward_models/res/script/ecl_run.py b/src/ert/resources/forward_models/res/script/ecl_run.py index a348b5d8f9a..561b920a165 100644 --- a/src/ert/resources/forward_models/res/script/ecl_run.py +++ b/src/ert/resources/forward_models/res/script/ecl_run.py @@ -276,7 +276,7 @@ def __init__( data_file = input_arg + ".DATA" if not os.path.isfile(data_file): - raise IOError(f"No such file: {data_file}") + raise OSError(f"No such file: {data_file}") (self.run_path, self.data_file) = os.path.split(data_file) (self.base_name, ext) = os.path.splitext(self.data_file) @@ -385,7 +385,7 @@ def execEclipse(self, eclrun_config=None) -> int: with pushd(self.run_path): if not os.path.exists(self.data_file): - raise IOError(f"Can not find data_file:{self.data_file}") + raise OSError(f"Can not find data_file:{self.data_file}") if not os.access(self.data_file, os.R_OK): raise OSError(f"Can not read data file:{self.data_file}") @@ -493,7 +493,7 @@ def readECLEND(self): errors = None bugs = None - with open(report_file, "r", encoding="utf-8") as filehandle: + with open(report_file, encoding="utf-8") as filehandle: for line in filehandle.readlines(): error_match = re.match(error_regexp, line) if error_match: @@ -515,7 +515,7 @@ def parseErrors(self) -> list[str]: error_e100_regexp = re.compile(error_pattern_e100, re.MULTILINE) error_e300_regexp = re.compile(error_pattern_e300, re.MULTILINE) slave_started_regexp = re.compile(slave_started_pattern, re.MULTILINE) - with open(self.prt_path, "r", encoding="utf-8") as filehandle: + with open(self.prt_path, encoding="utf-8") as filehandle: content = filehandle.read() for regexp in [error_e100_regexp, error_e300_regexp, slave_started_regexp]: diff --git a/src/ert/resources/forward_models/template_render.py b/src/ert/resources/forward_models/template_render.py index d805a49c3a4..34a2bd97b69 100755 --- a/src/ert/resources/forward_models/template_render.py +++ b/src/ert/resources/forward_models/template_render.py @@ -27,7 +27,7 @@ def load_data(filename): except yaml.YAMLError as err: json_err = str(err) - raise IOError( + raise OSError( f"{filename} is neither yaml (err_msg={yaml_err}) nor json (err_msg={json_err})" ) diff --git a/src/ert/resources/shell_scripts/careful_copy_file.py b/src/ert/resources/shell_scripts/careful_copy_file.py index 3a36c60f654..de740327e96 100755 --- a/src/ert/resources/shell_scripts/careful_copy_file.py +++ b/src/ert/resources/shell_scripts/careful_copy_file.py @@ -28,7 +28,7 @@ def careful_copy_file(src, target=None): print(f"Copying file '{src}' -> '{target_file}'") shutil.copyfile(src, target_file) else: - raise IOError(f"Input argument:'{src}' does not correspond to an existing file") + raise OSError(f"Input argument:'{src}' does not correspond to an existing file") if __name__ == "__main__": @@ -39,5 +39,5 @@ def careful_copy_file(src, target=None): careful_copy_file(src, target) else: careful_copy_file(src) - except IOError as e: + except OSError as e: sys.exit(f"CAREFUL_COPY_FILE failed with the following error: {e}") diff --git a/src/ert/resources/shell_scripts/copy_directory.py b/src/ert/resources/shell_scripts/copy_directory.py index 5a9ad6e3200..f5cd9de067d 100755 --- a/src/ert/resources/shell_scripts/copy_directory.py +++ b/src/ert/resources/shell_scripts/copy_directory.py @@ -20,7 +20,7 @@ def copy_directory(src_path, target_path): target_path = os.path.join(target_path, src_basename) shutil.copytree(src_path, target_path, dirs_exist_ok=True) else: - raise IOError( + raise OSError( f"Input argument:'{src_path}' " "does not correspond to an existing directory" ) @@ -31,7 +31,7 @@ def copy_directory(src_path, target_path): target_path = sys.argv[2] try: copy_directory(src_path, target_path) - except IOError as e: + except OSError as e: sys.exit( f"COPY_DIRECTORY failed with the following error: {''.join(e.args[0])}" ) diff --git a/src/ert/resources/shell_scripts/copy_file.py b/src/ert/resources/shell_scripts/copy_file.py index 75c7e007596..4b6da833136 100755 --- a/src/ert/resources/shell_scripts/copy_file.py +++ b/src/ert/resources/shell_scripts/copy_file.py @@ -26,7 +26,7 @@ def copy_file(src, target=None): print(f"Copying file '{src}' -> '{target_file}'") shutil.copyfile(src, target_file) else: - raise IOError(f"Input argument:'{src}' does not correspond to an existing file") + raise OSError(f"Input argument:'{src}' does not correspond to an existing file") if __name__ == "__main__": @@ -37,5 +37,5 @@ def copy_file(src, target=None): copy_file(src, target) else: copy_file(src) - except IOError as e: + except OSError as e: sys.exit(f"COPY_FILE failed with the following error: {e}") diff --git a/src/ert/resources/shell_scripts/delete_directory.py b/src/ert/resources/shell_scripts/delete_directory.py index a8ecea22b3d..718b9923d3e 100755 --- a/src/ert/resources/shell_scripts/delete_directory.py +++ b/src/ert/resources/shell_scripts/delete_directory.py @@ -52,7 +52,7 @@ def delete_directory(path): delete_empty_directory(os.path.join(root, _dir)) else: - raise IOError(f"Entry:'{path}' is not a directory") + raise OSError(f"Entry:'{path}' is not a directory") delete_empty_directory(path) else: @@ -63,5 +63,5 @@ def delete_directory(path): try: for d in sys.argv[1:]: delete_directory(d) - except IOError as e: + except OSError as e: sys.exit(f"DELETE_DIRECTORY failed with the following error: {e}") diff --git a/src/ert/resources/shell_scripts/delete_file.py b/src/ert/resources/shell_scripts/delete_file.py index 8a7594b7b39..880e006016c 100755 --- a/src/ert/resources/shell_scripts/delete_file.py +++ b/src/ert/resources/shell_scripts/delete_file.py @@ -16,7 +16,7 @@ def delete_file(filename): f"Sorry you are not owner of file:{filename} - not deleted\n" ) else: - raise IOError(f"Entry:'{filename}' is not a regular file") + raise OSError(f"Entry:'{filename}' is not a regular file") elif os.path.islink(filename): os.remove(filename) else: @@ -27,5 +27,5 @@ def delete_file(filename): try: for file in sys.argv[1:]: delete_file(file) - except IOError as e: + except OSError as e: sys.exit(f"DELETE_FILE failed with the following error: {e}") diff --git a/src/ert/resources/shell_scripts/make_directory.py b/src/ert/resources/shell_scripts/make_directory.py index 58a370a2d9e..a660f4aa4a3 100755 --- a/src/ert/resources/shell_scripts/make_directory.py +++ b/src/ert/resources/shell_scripts/make_directory.py @@ -22,5 +22,5 @@ def mkdir(path): path = sys.argv[1] try: mkdir(path) - except IOError as e: + except OSError as e: sys.exit(f"MAKE_DIRECTORY failed with the following error: {e}") diff --git a/src/ert/resources/shell_scripts/move_directory.py b/src/ert/resources/shell_scripts/move_directory.py index ae980c96c85..22b58460ce2 100755 --- a/src/ert/resources/shell_scripts/move_directory.py +++ b/src/ert/resources/shell_scripts/move_directory.py @@ -14,7 +14,7 @@ def move_directory(src_dir, target): shutil.rmtree(target) shutil.move(src_dir, target) else: - raise IOError(f"Input argument {src_dir} is not an existing directory") + raise OSError(f"Input argument {src_dir} is not an existing directory") if __name__ == "__main__": @@ -22,5 +22,5 @@ def move_directory(src_dir, target): target = sys.argv[2] try: move_directory(src, target) - except IOError as e: + except OSError as e: sys.exit(f"MOVE_DIRECTORY failed with the following error: {e}") diff --git a/src/ert/resources/shell_scripts/move_file.py b/src/ert/resources/shell_scripts/move_file.py index f5aa168d7f3..4c60a12b4ee 100755 --- a/src/ert/resources/shell_scripts/move_file.py +++ b/src/ert/resources/shell_scripts/move_file.py @@ -15,7 +15,7 @@ def move_file(src_file, target): target = os.path.join(target, os.path.basename(src_file)) shutil.move(src_file, target) else: - raise IOError(f"Input argument {src_file} is not an existing file") + raise OSError(f"Input argument {src_file} is not an existing file") if __name__ == "__main__": @@ -23,5 +23,5 @@ def move_file(src_file, target): target = sys.argv[2] try: move_file(src, target) - except IOError as e: + except OSError as e: sys.exit(f"MOVE_FILE failed with the following error: {e}") diff --git a/src/ert/resources/shell_scripts/symlink.py b/src/ert/resources/shell_scripts/symlink.py index d157c6b456d..edb24fe89a5 100755 --- a/src/ert/resources/shell_scripts/symlink.py +++ b/src/ert/resources/shell_scripts/symlink.py @@ -21,7 +21,7 @@ def symlink(target, link_name): target_check = os.path.join(link_path, target) if not os.path.exists(target_check): - raise IOError( + raise OSError( f"{target} (target) and {link_name} (link_name) requested, " f"which implies that {target_check} must exist, but it does not." ) @@ -37,5 +37,5 @@ def symlink(target, link_name): link_name = sys.argv[2] try: symlink(target, link_name) - except IOError as e: + except OSError as e: sys.exit(f"SYMLINK failed with the following error: {e}") diff --git a/src/ert/resources/workflows/jobs/internal-gui/scripts/gen_data_rft_export.py b/src/ert/resources/workflows/jobs/internal-gui/scripts/gen_data_rft_export.py index 5a847c6bc02..bf2c7ead78f 100644 --- a/src/ert/resources/workflows/jobs/internal-gui/scripts/gen_data_rft_export.py +++ b/src/ert/resources/workflows/jobs/internal-gui/scripts/gen_data_rft_export.py @@ -13,7 +13,7 @@ def load_args(filename, column_names=None): rows = 0 columns = 0 - with open(filename, "r", encoding="utf-8") as fileH: + with open(filename, encoding="utf-8") as fileH: for line in fileH.readlines(): rows += 1 columns = max(columns, len(line.split())) diff --git a/src/ert/run_arg.py b/src/ert/run_arg.py index 1282402ab16..0189c682749 100644 --- a/src/ert/run_arg.py +++ b/src/ert/run_arg.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Union +from typing import TYPE_CHECKING from ert.runpaths import Runpaths @@ -27,9 +27,9 @@ class RunArg: def create_run_arguments( runpaths: Runpaths, - active_realizations: Union[List[bool], npt.NDArray[np.bool_]], + active_realizations: list[bool] | npt.NDArray[np.bool_], ensemble: Ensemble, -) -> List[RunArg]: +) -> list[RunArg]: iteration = ensemble.iteration run_args = [] runpaths.set_ert_ensemble(ensemble.name) diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index 26390a52e0e..ccc58c60003 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -10,10 +10,11 @@ import uuid from abc import ABC, abstractmethod from collections import defaultdict +from collections.abc import Generator, MutableSequence from contextlib import contextmanager from pathlib import Path from queue import SimpleQueue -from typing import TYPE_CHECKING, Generator, MutableSequence, Union, cast +from typing import TYPE_CHECKING, cast import numpy as np @@ -80,20 +81,20 @@ if TYPE_CHECKING: from ert.config import QueueConfig -StatusEvents = Union[ - FullSnapshotEvent, - SnapshotUpdateEvent, - EndEvent, - AnalysisEvent, - AnalysisStatusEvent, - AnalysisTimeEvent, - RunModelErrorEvent, - RunModelStatusEvent, - RunModelTimeEvent, - RunModelUpdateBeginEvent, - RunModelDataEvent, - RunModelUpdateEndEvent, -] +StatusEvents = ( + FullSnapshotEvent + | SnapshotUpdateEvent + | EndEvent + | AnalysisEvent + | AnalysisStatusEvent + | AnalysisTimeEvent + | RunModelErrorEvent + | RunModelStatusEvent + | RunModelTimeEvent + | RunModelUpdateBeginEvent + | RunModelDataEvent + | RunModelUpdateEndEvent +) class OutOfOrderSnapshotUpdateException(ValueError): @@ -508,7 +509,7 @@ async def run_monitor( EESnapshot, EESnapshotUpdate, ): - event = cast(Union[EESnapshot, EESnapshotUpdate], event) + event = cast(EESnapshot | EESnapshotUpdate, event) await asyncio.get_running_loop().run_in_executor( None, self.send_snapshot_event, diff --git a/src/ert/run_models/event.py b/src/ert/run_models/event.py index 6974aae11e1..424c8ec9235 100644 --- a/src/ert/run_models/event.py +++ b/src/ert/run_models/event.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from pathlib import Path -from typing import Optional from uuid import UUID from ert.analysis.event import DataSection @@ -35,7 +34,7 @@ class RunModelDataEvent(RunModelEvent): name: str data: DataSection - def write_as_csv(self, output_path: Optional[Path]) -> None: + def write_as_csv(self, output_path: Path | None) -> None: if output_path and self.data: self.data.to_csv(self.name, output_path / str(self.run_id)) @@ -44,7 +43,7 @@ def write_as_csv(self, output_path: Optional[Path]) -> None: class RunModelUpdateEndEvent(RunModelEvent): data: DataSection - def write_as_csv(self, output_path: Optional[Path]) -> None: + def write_as_csv(self, output_path: Path | None) -> None: if output_path and self.data: self.data.to_csv("Report", output_path / str(self.run_id)) @@ -52,8 +51,8 @@ def write_as_csv(self, output_path: Optional[Path]) -> None: @dataclass class RunModelErrorEvent(RunModelEvent): error_msg: str - data: Optional[DataSection] = None + data: DataSection | None = None - def write_as_csv(self, output_path: Optional[Path]) -> None: + def write_as_csv(self, output_path: Path | None) -> None: if output_path and self.data: self.data.to_csv("Report", output_path / str(self.run_id)) diff --git a/src/ert/run_models/everest_run_model.py b/src/ert/run_models/everest_run_model.py index 14d4723a941..60e24a2dc5a 100644 --- a/src/ert/run_models/everest_run_model.py +++ b/src/ert/run_models/everest_run_model.py @@ -9,17 +9,15 @@ import random import shutil from collections import defaultdict +from collections.abc import Callable, Mapping from dataclasses import dataclass from pathlib import Path from types import TracebackType from typing import ( TYPE_CHECKING, Any, - Callable, Literal, - Mapping, Protocol, - Type, ) import numpy as np @@ -33,11 +31,7 @@ from seba_sqlite import SqliteStorage from typing_extensions import TypedDict -from _ert.events import ( - EESnapshot, - EESnapshotUpdate, - Event, -) +from _ert.events import EESnapshot, EESnapshotUpdate, Event from ert.config import ErtConfig, ExtParamConfig from ert.ensemble_evaluator import EnsembleSnapshot, EvaluatorServerConfig from ert.runpaths import Runpaths @@ -144,7 +138,7 @@ class OptimalResult: @staticmethod def from_seba_optimal_result( o: seba_sqlite.sqlite_storage.OptimalResult | None = None, - ) -> "OptimalResult" | None: + ) -> OptimalResult | None: if o is None: return None @@ -313,7 +307,7 @@ def _handle_errors( ) -> None: fm_id = "b_{}_r_{}_s_{}_{}".format(batch, realization, simulation, fm_name) fm_logger = logging.getLogger("forward_models") - with open(error_path, "r", encoding="utf-8") as errors: + with open(error_path, encoding="utf-8") as errors: error_str = errors.read() error_hash = hash(error_str) @@ -341,7 +335,7 @@ def onerror( _: Callable[..., Any], path: str, sys_info: tuple[ - Type[BaseException], BaseException, TracebackType + type[BaseException], BaseException, TracebackType ], ) -> None: logging.getLogger(EVEREST).debug( @@ -528,7 +522,7 @@ def _setup_sim( ensemble: Ensemble, ) -> None: def _check_suffix( - ext_config: "ExtParamConfig", + ext_config: ExtParamConfig, key: str, assignment: dict[str, Any] | tuple[str, str] | str | int, ) -> None: @@ -564,11 +558,9 @@ def _check_suffix( if isinstance(ext_config, ExtParamConfig): if len(ext_config) != len(control.keys()): raise KeyError( - ( - f"Expected {len(ext_config)} variables for " - f"control {control_name}, " - f"received {len(control.keys())}." - ) + f"Expected {len(ext_config)} variables for " + f"control {control_name}, " + f"received {len(control.keys())}." ) for var_name, var_setting in control.items(): _check_suffix(ext_config, var_name, var_setting) @@ -637,7 +629,7 @@ def _slug(entity: str) -> str: self._delete_runpath(run_args) # gather results - results: list[dict[str, "npt.NDArray[np.float64]"]] = [] + results: list[dict[str, npt.NDArray[np.float64]]] = [] for sim_id, successful in enumerate(self.active_realizations): if not successful: logger.error(f"Simulation {sim_id} failed.") diff --git a/src/ert/run_models/iterated_ensemble_smoother.py b/src/ert/run_models/iterated_ensemble_smoother.py index 32ccbca0fb1..466034b342a 100644 --- a/src/ert/run_models/iterated_ensemble_smoother.py +++ b/src/ert/run_models/iterated_ensemble_smoother.py @@ -211,12 +211,10 @@ def run_experiment( ) else: raise ErtRunError( - ( - "Iterated ensemble smoother stopped: " - "maximum number of iteration retries " - f"({self.num_retries_per_iter} retries) reached " - f"for iteration {prior_iter}" - ) + "Iterated ensemble smoother stopped: " + "maximum number of iteration retries " + f"({self.num_retries_per_iter} retries) reached " + f"for iteration {prior_iter}" ) prior = posterior diff --git a/src/ert/run_models/model_factory.py b/src/ert/run_models/model_factory.py index 68cb4a75d56..5f070bf0eca 100644 --- a/src/ert/run_models/model_factory.py +++ b/src/ert/run_models/model_factory.py @@ -2,7 +2,7 @@ import logging from queue import SimpleQueue -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING import numpy as np @@ -212,7 +212,7 @@ def _setup_ensemble_smoother( ) -def _determine_restart_info(args: Namespace) -> Tuple[bool, str]: +def _determine_restart_info(args: Namespace) -> tuple[bool, str]: """Handles differences in configuration between CLI and GUI. Returns diff --git a/src/ert/run_models/single_test_run.py b/src/ert/run_models/single_test_run.py index e1d4337aa9c..eda9cb97dab 100644 --- a/src/ert/run_models/single_test_run.py +++ b/src/ert/run_models/single_test_run.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from ert.config import ErtConfig from ert.run_models import EnsembleExperiment @@ -27,7 +27,7 @@ def __init__( self, ensemble_name: str, experiment_name: str, - random_seed: Optional[int], + random_seed: int | None, config: ErtConfig, storage: Storage, status_queue: SimpleQueue[StatusEvents], @@ -54,5 +54,5 @@ def description(cls) -> str: return "Sample parameters → evaluate single realization" @classmethod - def group(cls) -> Optional[str]: + def group(cls) -> str | None: return SINGLE_TEST_RUN_GROUP diff --git a/src/ert/runpaths.py b/src/ert/runpaths.py index a8c04e23155..d3a239e5b83 100644 --- a/src/ert/runpaths.py +++ b/src/ert/runpaths.py @@ -1,5 +1,5 @@ +from collections.abc import Iterable from pathlib import Path -from typing import Iterable from ert.substitutions import Substitutions diff --git a/src/ert/scheduler/driver.py b/src/ert/scheduler/driver.py index 159b5465f7f..3f935f591ff 100644 --- a/src/ert/scheduler/driver.py +++ b/src/ert/scheduler/driver.py @@ -4,8 +4,8 @@ import logging import shlex from abc import ABC, abstractmethod +from collections.abc import Iterable from pathlib import Path -from typing import Iterable from .event import Event @@ -103,7 +103,7 @@ async def _execute_with_retry( error_on_msgs: Iterable[str] = (), log_to_debug: bool | None = True, ) -> tuple[bool, str]: - _logger = driverlogger or logging.getLogger(__name__) + logger = driverlogger or logging.getLogger(__name__) error_message: str | None = None for i in range(total_attempts): @@ -127,14 +127,14 @@ async def _execute_with_retry( ) if process.returncode == 0: if retry_on_empty_stdout and not stdout: - _logger.warning( + logger.warning( f'Command "{shlex.join(cmd_with_args)}" gave exit code 0 but empty stdout, ' "will retry. " f'stderr: "{stderr.decode(errors="ignore").strip() or ""}"' ) else: if log_to_debug: - _logger.debug( + logger.debug( f'Command "{shlex.join(cmd_with_args)}" succeeded with {outputs}' ) return True, stdout.decode(errors="ignore").strip() @@ -152,7 +152,7 @@ async def _execute_with_retry( error_message = outputs elif process.returncode in accept_codes: if log_to_debug: - _logger.debug( + logger.debug( f'Command "{shlex.join(cmd_with_args)}" succeeded with {outputs}' ) return True, stderr.decode(errors="ignore").strip() @@ -160,7 +160,7 @@ async def _execute_with_retry( error_message = ( f'Command "{shlex.join(cmd_with_args)}" failed with {outputs}' ) - _logger.error(error_message) + logger.error(error_message) return False, error_message if i < (total_attempts - 1): await asyncio.sleep(retry_interval) @@ -168,5 +168,5 @@ async def _execute_with_retry( f'Command "{shlex.join(cmd_with_args)}" failed after {total_attempts} attempts ' f"with {outputs}" ) - _logger.error(error_message) + logger.error(error_message) return False, error_message diff --git a/src/ert/scheduler/event.py b/src/ert/scheduler/event.py index 7ba4af80ac9..b25dcd1874d 100644 --- a/src/ert/scheduler/event.py +++ b/src/ert/scheduler/event.py @@ -1,7 +1,6 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Union @dataclass @@ -17,4 +16,4 @@ class FinishedEvent: exec_hosts: str = "-" -Event = Union[StartedEvent, FinishedEvent] +Event = StartedEvent | FinishedEvent diff --git a/src/ert/scheduler/job.py b/src/ert/scheduler/job.py index 3a68557ee10..53b33b4e61b 100644 --- a/src/ert/scheduler/job.py +++ b/src/ert/scheduler/job.py @@ -5,7 +5,7 @@ import logging import time from contextlib import suppress -from enum import Enum +from enum import StrEnum from pathlib import Path from typing import TYPE_CHECKING, Any @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) -class JobState(str, Enum): +class JobState(StrEnum): WAITING = "WAITING" SUBMITTING = "SUBMITTING" PENDING = "PENDING" diff --git a/src/ert/scheduler/local_driver.py b/src/ert/scheduler/local_driver.py index cbee3f2ed64..5936539e381 100644 --- a/src/ert/scheduler/local_driver.py +++ b/src/ert/scheduler/local_driver.py @@ -5,9 +5,9 @@ import logging import signal from asyncio.subprocess import Process +from collections.abc import MutableMapping from contextlib import suppress from pathlib import Path -from typing import MutableMapping, Optional, Set from .driver import SIGNAL_OFFSET, Driver from .event import FinishedEvent, StartedEvent @@ -21,7 +21,7 @@ class LocalDriver(Driver): def __init__(self) -> None: super().__init__() self._tasks: MutableMapping[int, asyncio.Task[None]] = {} - self._sent_finished_events: Set[int] = set() + self._sent_finished_events: set[int] = set() async def submit( self, @@ -30,9 +30,9 @@ async def submit( /, *args: str, name: str = "dummy", - runpath: Optional[Path] = None, - num_cpu: Optional[int] = 1, - realization_memory: Optional[int] = 0, + runpath: Path | None = None, + num_cpu: int | None = 1, + realization_memory: int | None = 0, activate_script: str = "", ) -> None: self._tasks[iens] = asyncio.create_task(self._run(iens, executable, *args)) @@ -117,7 +117,7 @@ async def _kill(proc: Process) -> int: try: proc.terminate() await asyncio.wait_for(proc.wait(), _TERMINATE_TIMEOUT) - except asyncio.TimeoutError: + except TimeoutError: proc.kill() except ProcessLookupError: # This will happen if the subprocess has not yet started diff --git a/src/ert/scheduler/lsf_driver.py b/src/ert/scheduler/lsf_driver.py index ce691e9c165..6a0960c4264 100644 --- a/src/ert/scheduler/lsf_driver.py +++ b/src/ert/scheduler/lsf_driver.py @@ -9,16 +9,12 @@ import shutil import stat import time +from collections.abc import Iterable, Mapping, MutableMapping, Sequence from dataclasses import dataclass from pathlib import Path from tempfile import NamedTemporaryFile from typing import ( - Iterable, Literal, - Mapping, - MutableMapping, - Sequence, - Type, cast, get_args, ) @@ -79,7 +75,7 @@ class RunningJob: FinishedJobSuccess | FinishedJobFailure | QueuedJob | RunningJob | IgnoredJobstates ) -_STATE_ORDER: dict[Type[AnyJob], int] = { +_STATE_ORDER: dict[type[AnyJob], int] = { IgnoredJobstates: -1, QueuedJob: 0, RunningJob: 1, diff --git a/src/ert/scheduler/openpbs_driver.py b/src/ert/scheduler/openpbs_driver.py index ad63139c31f..03683611004 100644 --- a/src/ert/scheduler/openpbs_driver.py +++ b/src/ert/scheduler/openpbs_driver.py @@ -5,9 +5,10 @@ import logging import shlex import shutil +from collections.abc import Mapping, MutableMapping from dataclasses import dataclass from pathlib import Path -from typing import Any, Literal, Mapping, MutableMapping, Type, cast, get_type_hints +from typing import Any, Literal, cast, get_type_hints from .driver import Driver, FailedSubmit, create_submit_script from .event import Event, FinishedEvent, StartedEvent @@ -62,7 +63,7 @@ class FinishedJob: AnyJob = FinishedJob | QueuedJob | RunningJob | IgnoredJobstates -_STATE_ORDER: dict[Type[AnyJob], int] = { +_STATE_ORDER: dict[type[AnyJob], int] = { IgnoredJobstates: -1, QueuedJob: 0, RunningJob: 1, diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index 60939796c3a..a1610930b26 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -6,19 +6,11 @@ import time import traceback from collections import defaultdict +from collections.abc import Iterable, MutableMapping, Sequence from contextlib import suppress from dataclasses import asdict from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Iterable, - MutableMapping, - Optional, - Sequence, - Union, -) +from typing import TYPE_CHECKING, Any import orjson from pydantic.dataclasses import dataclass @@ -39,12 +31,12 @@ @dataclass class _JobsJson: - ens_id: Optional[str] + ens_id: str | None real_id: int - dispatch_url: Optional[str] - ee_token: Optional[str] - ee_cert_path: Optional[str] - experiment_id: Optional[str] + dispatch_url: str | None + ee_token: str | None + ee_cert_path: str | None + experiment_id: str | None class SubmitSleeper: @@ -68,17 +60,17 @@ class Scheduler: def __init__( self, driver: Driver, - realizations: Optional[Sequence[Realization]] = None, - manifest_queue: Optional[asyncio.Queue[Event]] = None, - ensemble_evaluator_queue: Optional[asyncio.Queue[Event]] = None, + realizations: Sequence[Realization] | None = None, + manifest_queue: asyncio.Queue[Event] | None = None, + ensemble_evaluator_queue: asyncio.Queue[Event] | None = None, *, max_submit: int = 1, max_running: int = 1, submit_sleep: float = 0.0, - ens_id: Optional[str] = None, - ee_uri: Optional[str] = None, - ee_cert: Optional[str] = None, - ee_token: Optional[str] = None, + ens_id: str | None = None, + ee_uri: str | None = None, + ee_cert: str | None = None, + ee_token: str | None = None, ) -> None: self.driver = driver self._ensemble_evaluator_queue = ensemble_evaluator_queue @@ -86,7 +78,7 @@ def __init__( self._job_tasks: MutableMapping[int, asyncio.Task[None]] = {} - self.submit_sleep_state: Optional[SubmitSleeper] = None + self.submit_sleep_state: SubmitSleeper | None = None if submit_sleep > 0: self.submit_sleep_state = SubmitSleeper(submit_sleep) @@ -114,7 +106,7 @@ def __init__( self._ee_cert = ee_cert self._ee_token = ee_token - self.checksum: Dict[str, Dict[str, Any]] = {} + self.checksum: dict[str, dict[str, Any]] = {} def kill_all_jobs(self) -> None: assert self._loop @@ -178,8 +170,8 @@ def set_realization(self, realization: Realization) -> None: def is_active(self) -> bool: return any(not task.done() for task in self._job_tasks.values()) - def count_states(self) -> Dict[JobState, int]: - counts: Dict[JobState, int] = defaultdict(int) + def count_states(self) -> dict[JobState, int]: + counts: dict[JobState, int] = defaultdict(int) for job in self._jobs.values(): counts[job.state] += 1 return counts @@ -227,10 +219,8 @@ async def _monitor_and_handle_tasks( ) ) logger.error( - ( - f"Exception in scheduler task {task.get_name()}: {task_exception}\n" - f"Traceback: {exc_traceback}" - ) + f"Exception in scheduler task {task.get_name()}: {task_exception}\n" + f"Traceback: {exc_traceback}" ) if task in scheduling_tasks: await self._cancel_job_tasks() @@ -251,7 +241,7 @@ async def _monitor_and_handle_tasks( async def execute( self, min_required_realizations: int = 0, - ) -> Union[Id.ENSEMBLE_SUCCEEDED_TYPE, Id.ENSEMBLE_CANCELLED_TYPE]: + ) -> Id.ENSEMBLE_SUCCEEDED_TYPE | Id.ENSEMBLE_CANCELLED_TYPE: scheduling_tasks = [ asyncio.create_task(self._publisher(), name="publisher_task"), asyncio.create_task( @@ -329,7 +319,7 @@ async def _process_event_queue(self) -> None: # Any event implies the job has at least started job.started.set() - if isinstance(event, (StartedEvent, FinishedEvent)) and event.exec_hosts: + if isinstance(event, StartedEvent | FinishedEvent) and event.exec_hosts: self._jobs[event.iens].exec_hosts = event.exec_hosts if ( diff --git a/src/ert/scheduler/slurm_driver.py b/src/ert/scheduler/slurm_driver.py index 4d28cb784ad..5258b25ba93 100644 --- a/src/ert/scheduler/slurm_driver.py +++ b/src/ert/scheduler/slurm_driver.py @@ -7,15 +7,11 @@ import shlex import stat import time +from collections.abc import Iterator from dataclasses import dataclass from enum import Enum, auto from pathlib import Path from tempfile import NamedTemporaryFile -from typing import ( - Iterator, - Optional, - Tuple, -) from .driver import SIGNAL_OFFSET, Driver, FailedSubmit, create_submit_script from .event import Event, FinishedEvent, StartedEvent @@ -38,8 +34,8 @@ class JobStatus(Enum): @dataclass class JobData: iens: int - exit_code: Optional[int] = None - status: Optional[JobStatus] = None + exit_code: int | None = None + status: JobStatus | None = None END_STATES = {JobStatus.FAILED, JobStatus.COMPLETED, JobStatus.CANCELLED} @@ -47,12 +43,12 @@ class JobData: @dataclass class JobInfo: - status: Optional[JobStatus] = None + status: JobStatus | None = None @dataclass class ScontrolInfo(JobInfo): - exit_code: Optional[int] = None + exit_code: int | None = None @dataclass @@ -70,14 +66,14 @@ def __init__( sacct_cmd: str = "sacct", scancel_cmd: str = "scancel", sbatch_cmd: str = "sbatch", - user: Optional[str] = None, - memory: Optional[str] = "", - realization_memory: Optional[int] = 0, - queue_name: Optional[str] = None, - memory_per_cpu: Optional[str] = None, - max_runtime: Optional[float] = None, + user: str | None = None, + memory: str | None = "", + realization_memory: int | None = 0, + queue_name: str | None = None, + memory_per_cpu: str | None = None, + max_runtime: float | None = None, squeue_timeout: float = 2, - project_code: Optional[str] = None, + project_code: str | None = None, activate_script: str = "", ) -> None: """ @@ -135,8 +131,8 @@ def __init__( def _submit_cmd( self, name: str = "dummy", - runpath: Optional[Path] = None, - num_cpu: Optional[int] = 1, + runpath: Path | None = None, + num_cpu: int | None = 1, ) -> list[str]: sbatch_with_args = [ str(self._sbatch), @@ -177,15 +173,15 @@ async def submit( /, *args: str, name: str = "dummy", - runpath: Optional[Path] = None, - num_cpu: Optional[int] = 1, - realization_memory: Optional[int] = 0, + runpath: Path | None = None, + num_cpu: int | None = 1, + realization_memory: int | None = 0, ) -> None: if runpath is None: runpath = Path.cwd() script = create_submit_script(runpath, executable, args, self.activate_script) - script_path: Optional[Path] = None + script_path: Path | None = None try: with NamedTemporaryFile( dir=runpath, @@ -324,7 +320,7 @@ async def _process_job_update(self, job_id: str, new_info: JobInfo) -> None: return self._jobs[job_id].status = new_state - event: Optional[Event] = None + event: Event | None = None if new_state == JobStatus.RUNNING: logger.debug(f"Realization {iens} is running") event = StartedEvent(iens=iens) @@ -360,9 +356,7 @@ async def _get_exit_code(self, job_id: str) -> int: return code return SLURM_FAILED_EXIT_CODE_FETCH - async def _poll_once_by_scontrol( - self, missing_job_id: str - ) -> Optional[ScontrolInfo]: + async def _poll_once_by_scontrol(self, missing_job_id: str) -> ScontrolInfo | None: if ( time.time() - self._scontrol_cache_timestamp < self._scontrol_required_cache_age @@ -380,7 +374,7 @@ async def _poll_once_by_scontrol( self._scontrol_cache_timestamp = time.time() return info - async def _run_scontrol(self, missing_job_id: str) -> Optional[ScontrolInfo]: + async def _run_scontrol(self, missing_job_id: str) -> ScontrolInfo | None: process = await asyncio.create_subprocess_exec( self._scontrol, "show", @@ -406,7 +400,7 @@ async def _run_scontrol(self, missing_job_id: str) -> Optional[ScontrolInfo]: ) return None - async def _run_sacct(self, missing_job_id: str) -> Optional[ScontrolInfo]: + async def _run_sacct(self, missing_job_id: str) -> ScontrolInfo | None: try: process = await asyncio.create_subprocess_exec( self._sacct, @@ -467,7 +461,7 @@ def _tail_textfile(file_path: Path, num_chars: int) -> str: return file.read()[-num_chars:] -def _parse_squeue_output(output: str) -> Iterator[Tuple[str, SqueueInfo]]: +def _parse_squeue_output(output: str) -> Iterator[tuple[str, SqueueInfo]]: for line in output.split("\n"): if line: id, status = line.split() diff --git a/src/ert/services/_base_service.py b/src/ert/services/_base_service.py index a83da55b30f..39e47802fbd 100644 --- a/src/ert/services/_base_service.py +++ b/src/ert/services/_base_service.py @@ -7,24 +7,14 @@ import signal import sys import threading +from collections.abc import Callable, Mapping, Sequence from logging import Logger, getLogger from pathlib import Path from select import PIPE_BUF, select from subprocess import Popen, TimeoutExpired from time import sleep from types import FrameType -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Generic, - Mapping, - Optional, - Self, - Sequence, - Type, - TypeVar, -) +from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar if TYPE_CHECKING: from inspect import Traceback @@ -36,7 +26,7 @@ SERVICE_CONF_PATHS: set[str] = set() -def cleanup_service_files(signum: int, frame: Optional[FrameType]) -> None: +def cleanup_service_files(signum: int, frame: FrameType | None) -> None: for file_path in SERVICE_CONF_PATHS: file = Path(file_path) if file.exists(): @@ -79,7 +69,7 @@ def __enter__(self) -> T: def __exit__( self, - exc_type: Type[BaseException], + exc_type: type[BaseException], exc_value: BaseException, traceback: Traceback, ) -> bool: @@ -161,7 +151,7 @@ def shutdown(self) -> int: self.join() return self._childproc.returncode - def _read_conn_info(self, proc: Popen[bytes]) -> Optional[str]: + def _read_conn_info(self, proc: Popen[bytes]) -> str | None: comm_buf = io.StringIO() first_iter = True while first_iter or proc.poll() is None: @@ -239,19 +229,19 @@ class BaseService: # initialisation is finished and it will try to read the JSON data. """ - _instance: Optional["BaseService"] = None + _instance: BaseService | None = None def __init__( self, exec_args: Sequence[str] = (), timeout: int = 120, conn_info: ConnInfo = None, - project: Optional[str] = None, + project: str | None = None, ): self._exec_args = exec_args self._timeout = timeout - self._proc: Optional[_Proc] = None + self._proc: _Proc | None = None self._conn_info: ConnInfo = conn_info self._conn_info_event = threading.Event() self._project = Path(project) if project is not None else Path.cwd() @@ -265,7 +255,7 @@ def __init__( ) @classmethod - def start_server(cls: Type[T], *args: Any, **kwargs: Any) -> _Context[T]: + def start_server(cls: type[T], *args: Any, **kwargs: Any) -> _Context[T]: if cls._instance is not None: raise RuntimeError("Server already running") cls._instance = obj = cls(*args, **kwargs) @@ -277,8 +267,8 @@ def start_server(cls: Type[T], *args: Any, **kwargs: Any) -> _Context[T]: def connect( cls, *, - project: Optional[os.PathLike[str]] = None, - timeout: Optional[int] = None, + project: os.PathLike[str] | None = None, + timeout: int | None = None, ) -> Self: if cls._instance is not None: cls._instance.wait_until_ready() @@ -303,7 +293,7 @@ def connect( raise TimeoutError("Server not started") @classmethod - def connect_or_start_server(cls: Type[T], *args: Any, **kwargs: Any) -> _Context[T]: + def connect_or_start_server(cls: type[T], *args: Any, **kwargs: Any) -> _Context[T]: with contextlib.suppress(TimeoutError): # Note that timeout==0 will bypass the loop in connect() and force # TimeoutError if there is no known existing instance @@ -311,7 +301,7 @@ def connect_or_start_server(cls: Type[T], *args: Any, **kwargs: Any) -> _Context # Server is not running. Start a new one return cls.start_server(*args, **kwargs) - def wait_until_ready(self, timeout: Optional[int] = None) -> bool: + def wait_until_ready(self, timeout: int | None = None) -> bool: if timeout is None: timeout = self._timeout diff --git a/src/ert/services/storage_service.py b/src/ert/services/storage_service.py index d1a1226e876..734ecce5c91 100644 --- a/src/ert/services/storage_service.py +++ b/src/ert/services/storage_service.py @@ -1,7 +1,8 @@ from __future__ import annotations import logging -from typing import Any, Mapping, Optional, Sequence, Tuple, Union +from collections.abc import Mapping, Sequence +from typing import Any import httpx import requests @@ -17,11 +18,11 @@ def __init__( self, exec_args: Sequence[str] = (), timeout: int = 120, - conn_info: Union[Mapping[str, Any], Exception, None] = None, - project: Optional[str] = None, + conn_info: Mapping[str, Any] | Exception | None = None, + project: str | None = None, verbose: bool = False, ): - self._url: Optional[str] = None + self._url: str | None = None exec_args = local_exec_args("storage") @@ -31,7 +32,7 @@ def __init__( super().__init__(exec_args, timeout, conn_info, project) - def fetch_auth(self) -> Tuple[str, Any]: + def fetch_auth(self) -> tuple[str, Any]: """ Returns a tuple of username and password, compatible with requests' `auth` kwarg. @@ -74,7 +75,7 @@ def fetch_url(self) -> str: ) @classmethod - def session(cls, timeout: Optional[int] = None) -> Client: + def session(cls, timeout: int | None = None) -> Client: """ Start a HTTP transaction with the server """ @@ -88,7 +89,7 @@ def session(cls, timeout: Optional[int] = None) -> Client: @classmethod async def async_session( cls, - timeout: Optional[int] = None, # noqa: ASYNC109 + timeout: int | None = None, # noqa: ASYNC109 ) -> httpx.AsyncClient: """ Start a HTTP transaction with the server diff --git a/src/ert/shared/_doc_utils/everest_jobs.py b/src/ert/shared/_doc_utils/everest_jobs.py index 065ba08c32c..cf184ce7d16 100644 --- a/src/ert/shared/_doc_utils/everest_jobs.py +++ b/src/ert/shared/_doc_utils/everest_jobs.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, ClassVar, Dict, List +from typing import Any, ClassVar from docutils import nodes from sphinx.util.docutils import SphinxDirective @@ -27,8 +27,8 @@ class _EverestDocumentation(SphinxDirective): def _generate_job_documentation( self, - docs: Dict[str, Any], - ) -> List[nodes.section]: + docs: dict[str, Any], + ) -> list[nodes.section]: if not docs: node = nodes.section(ids=[_escape_id("no-forward-models-category")]) node.append(nodes.literal_block(text="No forward model jobs installed")) @@ -62,7 +62,7 @@ class EverestForwardModelDocumentation(_EverestDocumentation): pm = EverestPluginManager() _JOBS: ClassVar[dict[str, Any]] = {**pm.get_documentation()} - def run(self) -> List[nodes.section]: + def run(self) -> list[nodes.section]: return self._generate_job_documentation( EverestForwardModelDocumentation._JOBS, ) diff --git a/src/ert/shared/_doc_utils/forward_model_documentation.py b/src/ert/shared/_doc_utils/forward_model_documentation.py index 715264f63fa..ddd03a4a285 100644 --- a/src/ert/shared/_doc_utils/forward_model_documentation.py +++ b/src/ert/shared/_doc_utils/forward_model_documentation.py @@ -1,7 +1,8 @@ from __future__ import annotations from argparse import ArgumentParser -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any from docutils import nodes @@ -15,9 +16,9 @@ def __init__( category: str, job_source: str, description: str, - job_config_file: Optional[str], - parser: Optional[Callable[[], ArgumentParser]], - examples: Optional[str] = "", + job_config_file: str | None, + parser: Callable[[], ArgumentParser] | None, + examples: str | None = "", ) -> None: self.name = name self.job_source = job_source diff --git a/src/ert/shared/net_utils.py b/src/ert/shared/net_utils.py index 66c12aef6c9..fbcb80cebf7 100644 --- a/src/ert/shared/net_utils.py +++ b/src/ert/shared/net_utils.py @@ -1,7 +1,6 @@ import logging import random import socket -from typing import Optional from dns import exception, resolver, reversename @@ -47,8 +46,8 @@ def get_machine_name() -> str: def find_available_socket( - custom_host: Optional[str] = None, - custom_range: Optional[range] = None, + custom_host: str | None = None, + custom_range: range | None = None, will_close_then_reopen_socket: bool = False, ) -> socket.socket: """ @@ -135,7 +134,7 @@ def get_family(host: str) -> socket.AddressFamily: try: socket.inet_pton(socket.AF_INET6, host) return socket.AF_INET6 - except socket.error: + except OSError: return socket.AF_INET diff --git a/src/ert/shared/plugins/plugin_response.py b/src/ert/shared/plugins/plugin_response.py index 63beece9966..16f0d921041 100644 --- a/src/ert/shared/plugins/plugin_response.py +++ b/src/ert/shared/plugins/plugin_response.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Callable from functools import wraps -from typing import Callable, Optional, ParamSpec, TypeVar +from typing import ParamSpec, TypeVar from ert.plugins.plugin_response import PluginMetadata, PluginResponse @@ -11,10 +12,10 @@ def plugin_response( plugin_name: str = "", -) -> Callable[[Callable[P, T]], Callable[P, Optional[PluginResponse[T]]]]: - def outer(func: Callable[P, T]) -> Callable[P, Optional[PluginResponse[T]]]: +) -> Callable[[Callable[P, T]], Callable[P, PluginResponse[T] | None]]: + def outer(func: Callable[P, T]) -> Callable[P, PluginResponse[T] | None]: @wraps(func) - def inner(*args: P.args, **kwargs: P.kwargs) -> Optional[PluginResponse[T]]: + def inner(*args: P.args, **kwargs: P.kwargs) -> PluginResponse[T] | None: response = func(*args, **kwargs) return ( PluginResponse(response, PluginMetadata(plugin_name, func.__name__)) diff --git a/src/ert/shared/storage/connection.py b/src/ert/shared/storage/connection.py index f72cb555895..431ec62c965 100644 --- a/src/ert/shared/storage/connection.py +++ b/src/ert/shared/storage/connection.py @@ -1,14 +1,14 @@ from __future__ import annotations import os -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any from ert.services import StorageService def get_info( - project_id: Optional[os.PathLike[str]] = None, -) -> Dict[str, Union[str, Tuple[str, Any]]]: + project_id: os.PathLike[str] | None = None, +) -> dict[str, str | tuple[str, Any]]: client = StorageService.connect(project=project_id) return { "baseurl": client.fetch_url(), diff --git a/src/ert/shared/storage/extraction.py b/src/ert/shared/storage/extraction.py index 3a6cb4845fa..c22ee9c6e35 100644 --- a/src/ert/shared/storage/extraction.py +++ b/src/ert/shared/storage/extraction.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, Mapping, Union +from collections.abc import Mapping from ert.config.gen_kw_config import GenKwConfig from ert.storage import Experiment @@ -22,13 +22,13 @@ def create_priors( experiment: Experiment, -) -> Mapping[str, Dict[str, Union[str, float]]]: +) -> Mapping[str, dict[str, str | float]]: priors_dict = {} for group, priors in experiment.parameter_configuration.items(): if isinstance(priors, GenKwConfig): for func in priors.transform_functions: - prior: Dict[str, Union[str, float]] = { + prior: dict[str, str | float] = { "function": _PRIOR_NAME_MAP[func.transform_function_name], } for name, value in func.parameter_list.items(): diff --git a/src/ert/storage/__init__.py b/src/ert/storage/__init__.py index 5fc6a543e79..ccc8bc0bba0 100644 --- a/src/ert/storage/__init__.py +++ b/src/ert/storage/__init__.py @@ -2,7 +2,6 @@ import os from pathlib import Path -from typing import Union from ert.storage.local_ensemble import LocalEnsemble from ert.storage.local_experiment import LocalExperiment @@ -33,7 +32,7 @@ class ErtStorageException(Exception): def open_storage( - path: Union[str, os.PathLike[str]], mode: Union[ModeLiteral, Mode] = "r" + path: str | os.PathLike[str], mode: ModeLiteral | Mode = "r" ) -> Storage: try: return LocalStorage(Path(path), Mode(mode)) diff --git a/src/ert/storage/local_ensemble.py b/src/ert/storage/local_ensemble.py index 8d425c9652b..5c10872b57f 100644 --- a/src/ert/storage/local_ensemble.py +++ b/src/ert/storage/local_ensemble.py @@ -3,10 +3,11 @@ import contextlib import logging import os +from collections.abc import Iterable from datetime import datetime -from functools import lru_cache +from functools import cache, lru_cache from pathlib import Path -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING from uuid import UUID import numpy as np @@ -86,7 +87,7 @@ def __init__( ) self._error_log_name = "error.json" - @lru_cache(maxsize=None) + @cache def create_realization_dir(realization: int) -> Path: return self._path / f"realization-{realization}" @@ -455,12 +456,12 @@ def _responses_exist_for_realization( return True path = self._realization_dir(realization) - def _has_response(_key: str) -> bool: - if _key in self.experiment.response_key_to_response_type: - _response_type = self.experiment.response_key_to_response_type[_key] - return (path / f"{_response_type}.parquet").exists() + def _has_response(key_: str) -> bool: + if key_ in self.experiment.response_key_to_response_type: + response_type = self.experiment.response_key_to_response_type[key_] + return (path / f"{response_type}.parquet").exists() - return (path / f"{_key}.parquet").exists() + return (path / f"{key_}.parquet").exists() if key: return _has_response(key) @@ -483,20 +484,20 @@ def _has_response(_key: str) -> bool: ) def _find_state(realization: int) -> set[RealizationStorageState]: - _state = set() + state = set() if self.has_failure(realization): failure = self.get_failure(realization) assert failure - _state.add(failure.type) + state.add(failure.type) if _responses_exist_for_realization(realization): - _state.add(RealizationStorageState.RESPONSES_LOADED) + state.add(RealizationStorageState.RESPONSES_LOADED) if _parameters_exist_for_realization(realization): - _state.add(RealizationStorageState.PARAMETERS_LOADED) + state.add(RealizationStorageState.PARAMETERS_LOADED) - if len(_state) == 0: - _state.add(RealizationStorageState.UNDEFINED) + if len(state) == 0: + state.add(RealizationStorageState.UNDEFINED) - return _state + return state return [_find_state(i) for i in range(self.ensemble_size)] @@ -522,7 +523,7 @@ def _load_dataset( group: str, realizations: int | np.int64 | npt.NDArray[np.int_] | None, ) -> xr.Dataset: - if isinstance(realizations, (int, np.int64)): + if isinstance(realizations, int | np.int64): return self._load_single_dataset(group, int(realizations)).isel( realizations=0, drop=True ) diff --git a/src/ert/storage/local_experiment.py b/src/ert/storage/local_experiment.py index c21ad0922a6..16ebef794b4 100644 --- a/src/ert/storage/local_experiment.py +++ b/src/ert/storage/local_experiment.py @@ -1,10 +1,11 @@ from __future__ import annotations import json +from collections.abc import Generator from datetime import datetime from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Any, Generator +from typing import TYPE_CHECKING, Any from uuid import UUID import numpy as np @@ -12,12 +13,7 @@ import xtgeo from pydantic import BaseModel -from ert.config import ( - ExtParamConfig, - Field, - GenKwConfig, - SurfaceConfig, -) +from ert.config import ExtParamConfig, Field, GenKwConfig, SurfaceConfig from ert.config.parsing.context_values import ContextBoolEncoder from ert.config.response_config import ResponseConfig from ert.storage.mode import BaseMode, Mode, require_write @@ -226,7 +222,7 @@ def metadata(self) -> dict[str, Any]: path = self.mount_point / self._metadata_file if not path.exists(): raise ValueError(f"{self._metadata_file!s} does not exist") - with open(path, encoding="utf-8", mode="r") as f: + with open(path, encoding="utf-8") as f: return json.load(f) @property @@ -251,7 +247,7 @@ def parameter_info(self) -> dict[str, Any]: path = self.mount_point / self._parameter_file if not path.exists(): raise ValueError(f"{self._parameter_file!s} does not exist") - with open(path, encoding="utf-8", mode="r") as f: + with open(path, encoding="utf-8") as f: info = json.load(f) return info @@ -261,7 +257,7 @@ def response_info(self) -> dict[str, Any]: path = self.mount_point / self._responses_file if not path.exists(): raise ValueError(f"{self._responses_file!s} does not exist") - with open(path, encoding="utf-8", mode="r") as f: + with open(path, encoding="utf-8") as f: info = json.load(f) return info diff --git a/src/ert/storage/local_storage.py b/src/ert/storage/local_storage.py index b0fcca3f4c9..679d452e7c7 100644 --- a/src/ert/storage/local_storage.py +++ b/src/ert/storage/local_storage.py @@ -5,13 +5,14 @@ import logging import os import shutil +from collections.abc import Generator, MutableSequence from datetime import datetime from functools import cached_property from pathlib import Path from tempfile import NamedTemporaryFile from textwrap import dedent from types import TracebackType -from typing import Any, Generator, MutableSequence, Type +from typing import Any from uuid import UUID, uuid4 import polars @@ -23,11 +24,7 @@ from ert.shared import __version__ from ert.storage.local_ensemble import LocalEnsemble from ert.storage.local_experiment import LocalExperiment -from ert.storage.mode import ( - BaseMode, - Mode, - require_write, -) +from ert.storage.mode import BaseMode, Mode, require_write from ert.storage.realization_storage_state import RealizationStorageState logger = logging.getLogger(__name__) @@ -251,7 +248,7 @@ def __enter__(self) -> LocalStorage: def __exit__( self, exception: Exception, - exception_type: Type[Exception], + exception_type: type[Exception], traceback: TracebackType, ) -> None: self.close() diff --git a/src/ert/storage/migration/to5.py b/src/ert/storage/migration/to5.py index d7ad8b89767..1b7c79108f4 100644 --- a/src/ert/storage/migration/to5.py +++ b/src/ert/storage/migration/to5.py @@ -36,7 +36,7 @@ def migrate(path: Path) -> None: ) responses_file = experiment / "responses.json" - with open(responses_file, encoding="utf-8", mode="r") as f: + with open(responses_file, encoding="utf-8") as f: info = json.load(f) for key, values in list(info.items()): if values.get("_ert_kind") == "SummaryConfig" and not values.get("keys"): diff --git a/src/ert/storage/migration/to7.py b/src/ert/storage/migration/to7.py index 7168b892d68..c1036d0f71a 100644 --- a/src/ert/storage/migration/to7.py +++ b/src/ert/storage/migration/to7.py @@ -88,7 +88,7 @@ def _migrate_response_datasets(path: Path) -> None: experiment_id = exp_index["id"] responses_file = experiment / "responses.json" - with open(responses_file, encoding="utf-8", mode="r") as f: + with open(responses_file, encoding="utf-8") as f: responses_obj = json.load(f) assert ( diff --git a/src/ert/storage/migration/to8.py b/src/ert/storage/migration/to8.py index 4de17545c17..6c054997fcf 100644 --- a/src/ert/storage/migration/to8.py +++ b/src/ert/storage/migration/to8.py @@ -127,15 +127,15 @@ def _migrate_observations_to_grouped_parquet(path: Path) -> None: if not os.path.exists(experiment / "observations"): os.makedirs(experiment / "observations") - _obs_keys = os.listdir(os.path.join(experiment, "observations")) + obs_keys = os.listdir(os.path.join(experiment, "observations")) - if len(set(_obs_keys) - {"summary", "gen_data"}) == 0: + if len(set(obs_keys) - {"summary", "gen_data"}) == 0: # Observations are already migrated, likely from .to4 migrations continue obs_ds_infos = [ ObservationDatasetInfo.from_path(experiment / "observations" / p) - for p in _obs_keys + for p in obs_keys ] for response_type in ["gen_data", "summary"]: diff --git a/src/ert/storage/migration/to9.py b/src/ert/storage/migration/to9.py index 2335a51bd98..5d8e07c2399 100644 --- a/src/ert/storage/migration/to9.py +++ b/src/ert/storage/migration/to9.py @@ -28,9 +28,7 @@ def migrate(path: Path) -> None: with ( open(experiment / "index.json", encoding="utf-8") as f_experiment, - open( - experiment / "responses.json", mode="r", encoding="utf-8" - ) as f_responses, + open(experiment / "responses.json", encoding="utf-8") as f_responses, ): exp_index = json.load(f_experiment) experiment_id = exp_index["id"] diff --git a/src/ert/storage/mode.py b/src/ert/storage/mode.py index 596ffe2efa8..4eeb920b273 100644 --- a/src/ert/storage/mode.py +++ b/src/ert/storage/mode.py @@ -1,8 +1,9 @@ from __future__ import annotations -from enum import Enum +from collections.abc import Callable +from enum import StrEnum from functools import wraps -from typing import Callable, Concatenate, Literal, ParamSpec, TypeVar +from typing import Concatenate, Literal, ParamSpec, TypeVar ModeLiteral = Literal["r", "w"] @@ -14,7 +15,7 @@ class ModeError(ValueError): """ -class Mode(str, Enum): +class Mode(StrEnum): """Enumeration representing the access modes for storage interaction.""" READ = "r" diff --git a/src/ert/substitutions.py b/src/ert/substitutions.py index 887fb88d2d8..c8608b3d1c7 100644 --- a/src/ert/substitutions.py +++ b/src/ert/substitutions.py @@ -2,7 +2,8 @@ import logging import re -from typing import Any, Mapping, Optional +from collections.abc import Mapping +from typing import Any from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler from pydantic.json_schema import JsonSchemaValue @@ -123,7 +124,7 @@ def _substitute( return substituted_string -def _replace_strings(substitutions: Mapping[str, str], string: str) -> Optional[str]: +def _replace_strings(substitutions: Mapping[str, str], string: str) -> str | None: start = 0 parts = [] for match in _PATTERN.finditer(string): diff --git a/src/ert/validation/ensemble_realizations_argument.py b/src/ert/validation/ensemble_realizations_argument.py index 4220b00b65f..7280ca995e9 100644 --- a/src/ert/validation/ensemble_realizations_argument.py +++ b/src/ert/validation/ensemble_realizations_argument.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from .range_string_argument import RangeStringArgument from .rangestring import rangestring_to_list @@ -14,7 +14,7 @@ class EnsembleRealizationsArgument(RangeStringArgument): ) def __init__( - self, ensemble: "Ensemble", max_value: Optional[int], **kwargs: bool + self, ensemble: "Ensemble", max_value: int | None, **kwargs: bool ) -> None: super().__init__(max_value, **kwargs) self.__ensemble = ensemble diff --git a/src/ert/validation/integer_argument.py b/src/ert/validation/integer_argument.py index 719a6b033fb..5b39b3e532b 100644 --- a/src/ert/validation/integer_argument.py +++ b/src/ert/validation/integer_argument.py @@ -1,5 +1,4 @@ import re -from typing import Optional from .argument_definition import ArgumentDefinition from .validation_status import ValidationStatus @@ -13,8 +12,8 @@ class IntegerArgument(ArgumentDefinition): def __init__( self, - from_value: Optional[int] = None, - to_value: Optional[int] = None, + from_value: int | None = None, + to_value: int | None = None, **kwargs: bool, ) -> None: super().__init__(**kwargs) diff --git a/src/ert/validation/range_string_argument.py b/src/ert/validation/range_string_argument.py index 6733d43428d..633b756f277 100644 --- a/src/ert/validation/range_string_argument.py +++ b/src/ert/validation/range_string_argument.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Optional - from .active_range import ActiveRange from .argument_definition import ArgumentDefinition from .validation_status import ValidationStatus @@ -15,7 +13,7 @@ class RangeStringArgument(ArgumentDefinition): ) VALUE_NOT_IN_RANGE = "A value must be in the range from 0 to %d." - def __init__(self, max_value: Optional[int] = None, **kwargs: bool) -> None: + def __init__(self, max_value: int | None = None, **kwargs: bool) -> None: super().__init__(**kwargs) self.__max_value = max_value diff --git a/src/ert/validation/rangestring.py b/src/ert/validation/rangestring.py index 80827e6ca02..c67337fe0f8 100644 --- a/src/ert/validation/rangestring.py +++ b/src/ert/validation/rangestring.py @@ -6,7 +6,7 @@ The ranges can overlap. The end of each range is inclusive. """ -from typing import Collection +from collections.abc import Collection def mask_to_rangestring(mask: Collection[bool | int]) -> str: diff --git a/src/ert/validation/validation_status.py b/src/ert/validation/validation_status.py index d0b08a771eb..70d2306d02e 100644 --- a/src/ert/validation/validation_status.py +++ b/src/ert/validation/validation_status.py @@ -1,12 +1,9 @@ -from typing import Optional - - class ValidationStatus: def __init__(self) -> None: super().__init__() self.__fail = False self.__message = "" - self.__value: Optional[str] = None + self.__value: str | None = None def setFailed(self) -> None: self.__fail = True @@ -23,7 +20,7 @@ def message(self) -> str: def setValue(self, value: str) -> None: self.__value = value - def value(self) -> Optional[str]: + def value(self) -> str | None: return self.__value def __bool__(self) -> bool: diff --git a/src/ert/workflow_runner.py b/src/ert/workflow_runner.py index 8a6a7df4526..232ed13a8f9 100644 --- a/src/ert/workflow_runner.py +++ b/src/ert/workflow_runner.py @@ -3,7 +3,7 @@ import logging from concurrent import futures from concurrent.futures import Future -from typing import TYPE_CHECKING, Any, Optional, Self +from typing import TYPE_CHECKING, Any, Self from ert.config import ErtConfig, ErtScript, ExternalErtScript, Workflow, WorkflowJob @@ -15,13 +15,13 @@ class WorkflowJobRunner: def __init__(self, workflow_job: WorkflowJob): self.job = workflow_job self.__running = False - self.__script: Optional[ErtScript] = None + self.__script: ErtScript | None = None self.stop_on_fail = False def run( self, - arguments: Optional[list[Any]] = None, - fixtures: Optional[dict[str, Any]] = None, + arguments: list[Any] | None = None, + fixtures: dict[str, Any] | None = None, ) -> Any: if arguments is None: arguments = [] @@ -107,22 +107,22 @@ class WorkflowRunner: def __init__( self, workflow: Workflow, - storage: Optional[Storage] = None, - ensemble: Optional[Ensemble] = None, - ert_config: Optional[ErtConfig] = None, + storage: Storage | None = None, + ensemble: Ensemble | None = None, + ert_config: ErtConfig | None = None, ) -> None: self.__workflow = workflow self.storage = storage self.ensemble = ensemble self.ert_config = ert_config - self.__workflow_result: Optional[bool] = None + self.__workflow_result: bool | None = None self._workflow_executor = futures.ThreadPoolExecutor(max_workers=1) - self._workflow_job: Optional[Future[None]] = None + self._workflow_job: Future[None] | None = None self.__running = False self.__cancelled = False - self.__current_job: Optional[WorkflowJobRunner] = None + self.__current_job: WorkflowJobRunner | None = None self.__status: dict[str, dict[str, Any]] = {} def __enter__(self) -> Self: @@ -131,8 +131,8 @@ def __enter__(self) -> Self: def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], + exc_type: type[BaseException] | None, + exc_value: BaseException | None, traceback: Any, ) -> None: self.wait() @@ -216,7 +216,7 @@ def cancel(self) -> None: self.__cancelled = True self.wait() - def exception(self) -> Optional[BaseException]: + def exception(self) -> BaseException | None: if self._workflow_job is not None: return self._workflow_job.exception() return None @@ -229,7 +229,7 @@ def wait(self) -> None: [self._workflow_job], timeout=None, return_when=futures.FIRST_EXCEPTION ) - def workflowResult(self) -> Optional[bool]: + def workflowResult(self) -> bool | None: return self.__workflow_result def workflowReport(self) -> dict[str, dict[str, Any]]: diff --git a/src/everest/bin/config_branch_script.py b/src/everest/bin/config_branch_script.py index 5b8ae0f5fb5..6c31d3d25d2 100644 --- a/src/everest/bin/config_branch_script.py +++ b/src/everest/bin/config_branch_script.py @@ -2,7 +2,7 @@ from copy import deepcopy as copy from functools import partial from os.path import exists, join -from typing import Any, Dict, Optional, Tuple +from typing import Any from ruamel.yaml import YAML from seba_sqlite.database import Database as seba_db @@ -13,7 +13,7 @@ from everest.config_keys import ConfigKeys as CK -def _yaml_config(file_path: str, parser) -> Tuple[str, Optional[Dict[str, Any]]]: +def _yaml_config(file_path: str, parser) -> tuple[str, dict[str, Any] | None]: loaded_config = EverestConfig.load_file_with_argparser(file_path, parser) assert loaded_config is not None diff --git a/src/everest/bin/main.py b/src/everest/bin/main.py index 60ac014de3e..7e9154bba92 100644 --- a/src/everest/bin/main.py +++ b/src/everest/bin/main.py @@ -26,7 +26,7 @@ def __init__( default=argparse.SUPPRESS, help=None, ): - super(_DumpAction, self).__init__( + super().__init__( option_strings=option_strings, dest=dest, default=default, @@ -77,7 +77,7 @@ def _build_args_parser(): return arg_parser -class EverestMain(object): +class EverestMain: def __init__(self, args): parser = _build_args_parser() # Parse_args defaults to [1:] for args, but you need to diff --git a/src/everest/bin/utils.py b/src/everest/bin/utils.py index aba12687576..7c891f729cb 100644 --- a/src/everest/bin/utils.py +++ b/src/everest/bin/utils.py @@ -4,7 +4,7 @@ import traceback from dataclasses import dataclass, field from itertools import groupby -from typing import ClassVar, Dict, List, Tuple +from typing import ClassVar import colorama from colorama import Fore @@ -100,7 +100,7 @@ def _format_list(values): @dataclass class JobProgress: name: str - status: Dict[str, List[int]] = field( + status: dict[str, list[int]] = field( default_factory=lambda: { JOB_RUNNING: [], # contains running simulation numbers i.e [7,8,9] JOB_SUCCESS: [], # contains successful simulation numbers i.e [0,1,3,4] @@ -113,7 +113,7 @@ class JobProgress: JOB_FAILURE: Fore.RED, } - def _status_string(self, max_widths: Dict[str, int]) -> str: + def _status_string(self, max_widths: dict[str, int]) -> str: string = [] for state in [JOB_RUNNING, JOB_SUCCESS, JOB_FAILURE]: number_of_simulations = len(self.status[state]) @@ -122,7 +122,7 @@ def _status_string(self, max_widths: Dict[str, int]) -> str: string.append(f"{color}{number_of_simulations:>{width}}{Fore.RESET}") return "/".join(string) - def progress_str(self, max_widths: Dict[str, int]) -> str: + def progress_str(self, max_widths: dict[str, int]) -> str: msg = "" for state in [JOB_SUCCESS, JOB_FAILURE]: simulations_list = _format_list(self.status[state]) @@ -299,7 +299,7 @@ def _clear(self): def run_detached_monitor( - server_context: Tuple[str, str, Tuple[str, str]], + server_context: tuple[str, str, tuple[str, str]], optimization_output_dir: str, show_all_jobs: bool = False, ): diff --git a/src/everest/config/control_config.py b/src/everest/config/control_config.py index cb6783717c2..689b5cb3508 100644 --- a/src/everest/config/control_config.py +++ b/src/everest/config/control_config.py @@ -1,9 +1,14 @@ from itertools import chain -from typing import Any, List, Literal, Optional, Tuple, Union +from typing import ( + Annotated, + Any, + Literal, + Self, + TypeAlias, +) from pydantic import AfterValidator, BaseModel, ConfigDict, Field, model_validator from ropt.enums import PerturbationType, VariableType -from typing_extensions import Annotated, Self, TypeAlias from .control_variable_config import ( ControlVariableConfig, @@ -17,9 +22,9 @@ valid_range, ) -ControlVariable: TypeAlias = Union[ - List[ControlVariableConfig], List[ControlVariableGuessListConfig] -] +ControlVariable: TypeAlias = ( + list[ControlVariableConfig] | list[ControlVariableGuessListConfig] +) def _all_or_no_index(variables: ControlVariable) -> ControlVariable: @@ -54,7 +59,7 @@ class ControlConfig(BaseModel): AfterValidator(_all_or_no_index), AfterValidator(unique_items), ] = Field(description="List of control variables", min_length=1) - initial_guess: Optional[float] = Field( + initial_guess: float | None = Field( default=None, description=""" Initial guess for the control group all control variables with initial_guess not @@ -71,7 +76,7 @@ class ControlConfig(BaseModel): different control types. """, ) - enabled: Optional[bool] = Field( + enabled: bool | None = Field( default=True, description=""" If `True`, all variables in this control group will be optimized. If set to `False` @@ -86,7 +91,7 @@ class ControlConfig(BaseModel): scaled_range (default [0, 1]). """, ) - min: Optional[float] = Field( + min: float | None = Field( default=None, description=""" Defines left-side value in the control group range [min, max]. @@ -96,7 +101,7 @@ class ControlConfig(BaseModel): in the resulting [min, max] range """, ) - max: Optional[float] = Field( + max: float | None = Field( default=None, description=""" Defines right-side value in the control group range [min, max]. @@ -120,7 +125,7 @@ class ControlConfig(BaseModel): ranges might have unintended effects. """, ) - perturbation_magnitude: Optional[float] = Field( + perturbation_magnitude: float | None = Field( default=None, description=""" Specifies the perturbation magnitude for a set of controls of a certain type. @@ -136,18 +141,18 @@ class ControlConfig(BaseModel): NOTE: In most cases this should not be configured, and the default value should be used. """, ) - scaled_range: Annotated[ - Optional[Tuple[float, float]], AfterValidator(valid_range) - ] = Field( - default=None, - description=""" + scaled_range: Annotated[tuple[float, float] | None, AfterValidator(valid_range)] = ( + Field( + default=None, + description=""" Can be used to set the range of the control values after scaling (default = [0, 1]). This option has no effect if auto_scale is not set. """, + ) ) - sampler: Optional[SamplerConfig] = Field( + sampler: SamplerConfig | None = Field( default=None, description=""" A sampler specification section applies to a group of controls, or to an diff --git a/src/everest/config/control_variable_config.py b/src/everest/config/control_variable_config.py index 8236de12d7c..a7896a9a60f 100644 --- a/src/everest/config/control_variable_config.py +++ b/src/everest/config/control_variable_config.py @@ -1,4 +1,4 @@ -from typing import Any, List, Literal, Optional, Tuple +from typing import Annotated, Any, Literal from pydantic import ( AfterValidator, @@ -9,7 +9,6 @@ PositiveFloat, ) from ropt.enums import VariableType -from typing_extensions import Annotated from everest.config.validation_utils import no_dots_in_string, valid_range @@ -22,39 +21,39 @@ class _ControlVariable(BaseModel): name: Annotated[str, AfterValidator(no_dots_in_string)] = Field( description="Control variable name" ) - control_type: Optional[Literal["real", "integer"]] = Field( + control_type: Literal["real", "integer"] | None = Field( default=None, description=""" The type of control. Set to "integer" for discrete optimization. This may be ignored if the algorithm that is used does not support different control types. """, ) - enabled: Optional[bool] = Field( + enabled: bool | None = Field( default=None, description=""" If `True`, the variable will be optimized, otherwise it will be fixed to the initial value. """, ) - auto_scale: Optional[bool] = Field( + auto_scale: bool | None = Field( default=None, description=""" Can be set to true to re-scale variable from the range defined by [min, max] to the range defined by scaled_range (default [0, 1]) """, ) - scaled_range: Annotated[ - Optional[Tuple[float, float]], AfterValidator(valid_range) - ] = Field( - default=None, - description=""" + scaled_range: Annotated[tuple[float, float] | None, AfterValidator(valid_range)] = ( + Field( + default=None, + description=""" Can be used to set the range of the variable values after scaling (default = [0, 1]). This option has no effect if auto_scale is not set. """, + ) ) - min: Optional[float] = Field( + min: float | None = Field( default=None, description=""" Minimal value allowed for the variable @@ -62,7 +61,7 @@ class _ControlVariable(BaseModel): initial_guess is required to be greater than this value. """, ) - max: Optional[float] = Field( + max: float | None = Field( default=None, description=""" Max value allowed for the variable @@ -70,7 +69,7 @@ class _ControlVariable(BaseModel): initial_guess is required to be less than this value. """, ) - perturbation_magnitude: Optional[PositiveFloat] = Field( + perturbation_magnitude: PositiveFloat | None = Field( default=None, description=""" Specifies the perturbation magnitude for this particular variable. @@ -81,24 +80,24 @@ class _ControlVariable(BaseModel): NOTE: In most cases this should not be configured, and the default value should be used. """, ) - sampler: Optional[SamplerConfig] = Field( + sampler: SamplerConfig | None = Field( default=None, description="The backend used by Everest for sampling points" ) @property - def ropt_control_type(self) -> Optional[VariableType]: + def ropt_control_type(self) -> VariableType | None: return VariableType[self.control_type.upper()] if self.control_type else None class ControlVariableConfig(_ControlVariable): model_config = ConfigDict(title="variable control") - initial_guess: Optional[float] = Field( + initial_guess: float | None = Field( default=None, description=""" Starting value for the control variable, if given needs to be in the interval [min, max] """, ) - index: Optional[NonNegativeInt] = Field( + index: NonNegativeInt | None = Field( default=None, description=""" Index should be given either for all of the variables or for none of them @@ -122,7 +121,7 @@ def uniqueness(self) -> str: class ControlVariableGuessListConfig(_ControlVariable): initial_guess: Annotated[ - List[float], + list[float], Field( default=None, description="List of Starting values for the control variable", diff --git a/src/everest/config/cvar_config.py b/src/everest/config/cvar_config.py index 98400473cd2..1eda051aae4 100644 --- a/src/everest/config/cvar_config.py +++ b/src/everest/config/cvar_config.py @@ -1,10 +1,8 @@ -from typing import Optional - from pydantic import BaseModel, ConfigDict, Field, model_validator class CVaRConfig(BaseModel): # type: ignore - number_of_realizations: Optional[int] = Field( + number_of_realizations: int | None = Field( default=None, description="""The number of realizations used for CVaR estimation. @@ -13,7 +11,7 @@ class CVaRConfig(BaseModel): # type: ignore This option is exclusive with the **percentile** option. """, ) - percentile: Optional[float] = Field( + percentile: float | None = Field( default=None, ge=0.0, le=1.0, diff --git a/src/everest/config/environment_config.py b/src/everest/config/environment_config.py index 948524058b7..080c7bb09b1 100644 --- a/src/everest/config/environment_config.py +++ b/src/everest/config/environment_config.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, Field, field_validator @@ -6,16 +6,15 @@ class EnvironmentConfig(BaseModel, extra="forbid"): # type: ignore - simulation_folder: Optional[str] = Field( + simulation_folder: str | None = Field( default="simulation_folder", description="Folder used for simulation by Everest" ) - output_folder: Optional[str] = Field( + output_folder: str | None = Field( default="everest_output", description="Folder for outputs of Everest" ) - log_level: Optional[Literal["debug", "info", "warning", "error", "critical"]] = ( - Field( - default="info", - description="""Defines the verbosity of logs output by Everest. + log_level: Literal["debug", "info", "warning", "error", "critical"] | None = Field( + default="info", + description="""Defines the verbosity of logs output by Everest. The default log level is `info`. All supported log levels are: @@ -34,9 +33,8 @@ class EnvironmentConfig(BaseModel, extra="forbid"): # type: ignore critical: A serious error, indicating that the program itself may be unable to continue running. """, - ) ) - random_seed: Optional[int] = Field( + random_seed: int | None = Field( default=None, description="Random seed (must be positive)" ) diff --git a/src/everest/config/everest_config.py b/src/everest/config/everest_config.py index ba012223b46..9131ac57337 100644 --- a/src/everest/config/everest_config.py +++ b/src/everest/config/everest_config.py @@ -6,11 +6,11 @@ from pathlib import Path from typing import ( TYPE_CHECKING, - Dict, - List, + Annotated, Literal, Optional, Protocol, + Self, no_type_check, ) @@ -24,7 +24,6 @@ model_validator, ) from ruamel.yaml import YAML, YAMLError -from typing_extensions import Annotated, Self from ert.config import ErtConfig from everest.config.control_variable_config import ControlVariableGuessListConfig @@ -102,20 +101,20 @@ class HasName(Protocol): class EverestConfig(BaseModelWithPropertySupport): # type: ignore - controls: Annotated[List[ControlConfig], AfterValidator(unique_items)] = Field( + controls: Annotated[list[ControlConfig], AfterValidator(unique_items)] = Field( description="""Defines a list of controls. Controls should have unique names each control defines a group of control variables """, ) - objective_functions: List[ObjectiveFunctionConfig] = Field( + objective_functions: list[ObjectiveFunctionConfig] = Field( description="List of objective function specifications", ) - optimization: Optional[OptimizationConfig] = Field( + optimization: OptimizationConfig | None = Field( default=OptimizationConfig(), description="Optimizer options", ) - model: Optional[ModelConfig] = Field( + model: ModelConfig | None = Field( default=ModelConfig(), description="Configuration of the Everest model", ) @@ -123,16 +122,16 @@ class EverestConfig(BaseModelWithPropertySupport): # type: ignore # It IS required but is currently used in a non-required manner by tests # Thus, it is to be made explicitly required as the other logic # is being rewritten - environment: Optional[EnvironmentConfig] = Field( + environment: EnvironmentConfig | None = Field( default=EnvironmentConfig(), description="The environment of Everest, specifies which folders are used " "for simulation and output, as well as the level of detail in Everest-logs", ) - wells: List[WellConfig] = Field( + wells: list[WellConfig] = Field( default_factory=list, description="A list of well configurations, all with unique names.", ) - definitions: Optional[dict] = Field( + definitions: dict | None = Field( default_factory=dict, description="""Section for specifying variables. @@ -157,25 +156,25 @@ class EverestConfig(BaseModelWithPropertySupport): # type: ignore | ... """, ) - input_constraints: Optional[List[InputConstraintConfig]] = Field( + input_constraints: list[InputConstraintConfig] | None = Field( default=None, description="List of input constraints" ) - output_constraints: Optional[List[OutputConstraintConfig]] = Field( + output_constraints: list[OutputConstraintConfig] | None = Field( default=None, description="A list of output constraints with unique names." ) - install_jobs: Optional[List[InstallJobConfig]] = Field( + install_jobs: list[InstallJobConfig] | None = Field( default=None, description="A list of jobs to install" ) - install_workflow_jobs: Optional[List[InstallJobConfig]] = Field( + install_workflow_jobs: list[InstallJobConfig] | None = Field( default=None, description="A list of workflow jobs to install" ) - install_data: Optional[List[InstallDataConfig]] = Field( + install_data: list[InstallDataConfig] | None = Field( default=None, description="""A list of install data elements from the install_data config section. Each item marks what folders or paths need to be copied or linked in order for the evaluation jobs to run.""", ) - install_templates: Optional[List[InstallTemplateConfig]] = Field( + install_templates: list[InstallTemplateConfig] | None = Field( default=None, description="""Allow the user to define the workflow establishing the model chain for the purpose of sensitivity analysis, enabling the relationship @@ -183,7 +182,7 @@ class EverestConfig(BaseModelWithPropertySupport): # type: ignore evaluated. """, ) - server: Optional[ServerConfig] = Field( + server: ServerConfig | None = Field( default=None, description="""Defines Everest server settings, i.e., which queue system, queue name and queue options are used for the everest server. @@ -201,16 +200,16 @@ class EverestConfig(BaseModelWithPropertySupport): # type: ignore requirements of the forward models. """, ) - simulator: Optional[SimulatorConfig] = Field( + simulator: SimulatorConfig | None = Field( default_factory=SimulatorConfig, description="Simulation settings" ) - forward_model: Optional[List[str]] = Field( + forward_model: list[str] | None = Field( default=None, description="List of jobs to run" ) - workflows: Optional[WorkflowConfig] = Field( + workflows: WorkflowConfig | None = Field( default=None, description="Workflows to run during optimization" ) - export: Optional[ExportConfig] = Field( + export: ExportConfig | None = Field( default=None, description="Settings to control the exports of a optimization run by everest.", ) @@ -248,7 +247,7 @@ def validate_install_job_sources(self) -> Self: # pylint: disable=E0213 continue exec_path = None valid_jobfile = True - with open(abs_config_path, "r", encoding="utf-8") as jobfile: + with open(abs_config_path, encoding="utf-8") as jobfile: for line in jobfile: if not line.startswith("EXECUTABLE"): continue @@ -515,7 +514,7 @@ def validate_that_environment_sim_folder_is_writeable(self) -> Self: @field_validator("wells") @no_type_check @classmethod - def validate_unique_well_names(cls, wells: List[WellConfig]): + def validate_unique_well_names(cls, wells: list[WellConfig]): check_for_duplicate_names([w.name for w in wells], "well", "name") return wells @@ -524,7 +523,7 @@ def validate_unique_well_names(cls, wells: List[WellConfig]): @no_type_check @classmethod def validate_unique_output_constraint_names( - cls, output_constraints: List[OutputConstraintConfig] + cls, output_constraints: list[OutputConstraintConfig] ): check_for_duplicate_names( [c.name for c in output_constraints], "output constraint", "name" @@ -597,14 +596,14 @@ def logging_level( env.log_level = level # pylint:disable = E0237 @property - def config_directory(self) -> Optional[str]: + def config_directory(self) -> str | None: if self.config_path is not None: return str(self.config_path.parent) return None @property - def config_file(self) -> Optional[str]: + def config_file(self) -> str | None: if self.config_path is not None: return self.config_path.name return None @@ -628,7 +627,7 @@ def output_dir(self) -> str: return os.path.join(cfgdir, path) @property - def simulation_dir(self) -> Optional[str]: + def simulation_dir(self) -> str | None: assert self.environment is not None path = self.environment.simulation_folder @@ -678,7 +677,7 @@ def result_names(self): return objectives_names + constraint_names @property - def function_aliases(self) -> Dict[str, str]: + def function_aliases(self) -> dict[str, str]: aliases = { objective.name: objective.alias for objective in self.objective_functions @@ -739,7 +738,7 @@ def with_defaults(cls, **kwargs): return EverestConfig.model_validate({**defaults, **kwargs}) @staticmethod - def lint_config_dict(config: dict) -> List["ErrorDetails"]: + def lint_config_dict(config: dict) -> list["ErrorDetails"]: try: EverestConfig.model_validate(config) return [] @@ -782,7 +781,7 @@ def load_file_with_argparser( f"{format_errors(e)}" ) - def dump(self, fname: Optional[str] = None) -> Optional[str]: + def dump(self, fname: str | None = None) -> str | None: """Write a config dict to file or return it if fname is None.""" stripped_conf = self.to_dict() diff --git a/src/everest/config/export_config.py b/src/everest/config/export_config.py index 60f612b4b54..c6eedf4e19d 100644 --- a/src/everest/config/export_config.py +++ b/src/everest/config/export_config.py @@ -1,35 +1,33 @@ -from typing import List, Optional - from pydantic import BaseModel, Field, field_validator from everest.config.validation_utils import check_writable_filepath class ExportConfig(BaseModel, extra="forbid"): # type: ignore - csv_output_filepath: Optional[str] = Field( + csv_output_filepath: str | None = Field( default=None, description="""Specifies which file to write the export to. Defaults to .csv in output folder.""", ) - discard_gradient: Optional[bool] = Field( + discard_gradient: bool | None = Field( default=None, description="If set to True, Everest export will not contain " "gradient simulation data.", ) - discard_rejected: Optional[bool] = Field( + discard_rejected: bool | None = Field( default=None, description="""If set to True, Everest export will contain only simulations that have the increase_merit flag set to true.""", ) - keywords: Optional[List[str]] = Field( + keywords: list[str] | None = Field( default=None, description="List of eclipse keywords to be exported into csv.", ) - batches: Optional[List[int]] = Field( + batches: list[int] | None = Field( default=None, description="list of batches to be exported, default is all batches.", ) - skip_export: Optional[bool] = Field( + skip_export: bool | None = Field( default=None, description="""set to True if export should not be run after the optimization case. diff --git a/src/everest/config/has_ert_queue_options.py b/src/everest/config/has_ert_queue_options.py index c427391455a..162ddc81b38 100644 --- a/src/everest/config/has_ert_queue_options.py +++ b/src/everest/config/has_ert_queue_options.py @@ -1,10 +1,10 @@ -from typing import Any, List, Tuple +from typing import Any class HasErtQueueOptions: def extract_ert_queue_options( - self, queue_system: str, everest_to_ert_key_tuples: List[Tuple[str, str]] - ) -> List[Tuple[str, str, Any]]: + self, queue_system: str, everest_to_ert_key_tuples: list[tuple[str, str]] + ) -> list[tuple[str, str, Any]]: result = [] for ever_key, ert_key in everest_to_ert_key_tuples: attribute = getattr(self, ever_key) diff --git a/src/everest/config/input_constraint_config.py b/src/everest/config/input_constraint_config.py index 3a5f82809c5..1bf9dc1fa38 100644 --- a/src/everest/config/input_constraint_config.py +++ b/src/everest/config/input_constraint_config.py @@ -1,10 +1,8 @@ -from typing import Dict, Optional - from pydantic import BaseModel, Field, field_validator class InputConstraintConfig(BaseModel, extra="forbid"): # type: ignore - weights: Dict[str, float] = Field( + weights: dict[str, float] = Field( description="""**Example** If we are trying to constrain only one control (i.e the z control) value: | input_constraints: @@ -18,7 +16,7 @@ class InputConstraintConfig(BaseModel, extra="forbid"): # type: ignore `x-0 * 0 + y-1 * 0 + z-2 * 1 > 0.2` """, ) - target: Optional[float] = Field( + target: float | None = Field( default=None, description="""**Example** | input_constraints: @@ -32,7 +30,7 @@ class InputConstraintConfig(BaseModel, extra="forbid"): # type: ignore `x-0 * 1 + y-1 * 2 + z-2 * 3 = 4` """, ) - lower_bound: Optional[float] = Field( + lower_bound: float | None = Field( default=None, description="""**Example** | input_constraints: @@ -47,7 +45,7 @@ class InputConstraintConfig(BaseModel, extra="forbid"): # type: ignore `x-0 * 1 + y-1 * 2 + z-2 * 3 >= 4` """, ) - upper_bound: Optional[float] = Field( + upper_bound: float | None = Field( default=None, description="""**Example** | input_constraints: @@ -65,7 +63,7 @@ class InputConstraintConfig(BaseModel, extra="forbid"): # type: ignore @field_validator("weights") @classmethod # pylint: disable=E0213 - def validate_weights_not_empty(cls, weights: Dict[str, float]) -> Dict[str, float]: + def validate_weights_not_empty(cls, weights: dict[str, float]) -> dict[str, float]: if weights is None or weights == {}: raise ValueError("Input weight data required for input constraints") return weights diff --git a/src/everest/config/install_data_config.py b/src/everest/config/install_data_config.py index 065045a7a99..be8773004fd 100644 --- a/src/everest/config/install_data_config.py +++ b/src/everest/config/install_data_config.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Optional from pydantic import BaseModel, Field, field_validator, model_validator @@ -16,7 +15,7 @@ class InstallDataConfig(BaseModel, extra="forbid"): # type: ignore Relative path to place the copy or link for the given source. """ ) # path - link: Optional[bool] = Field( + link: bool | None = Field( default=None, description=""" If set to true will create a link to the given source at the given target, diff --git a/src/everest/config/install_template_config.py b/src/everest/config/install_template_config.py index b5085403b10..d41aa7a288e 100644 --- a/src/everest/config/install_template_config.py +++ b/src/everest/config/install_template_config.py @@ -1,9 +1,7 @@ -from typing import Optional - from pydantic import BaseModel, Field class InstallTemplateConfig(BaseModel, extra="forbid"): # type: ignore template: str = Field() # existing file output_file: str = Field() # path - extra_data: Optional[str] = Field(default=None) # path + extra_data: str | None = Field(default=None) # path diff --git a/src/everest/config/model_config.py b/src/everest/config/model_config.py index 3d70542fd02..9fe95affe02 100644 --- a/src/everest/config/model_config.py +++ b/src/everest/config/model_config.py @@ -1,25 +1,23 @@ -from typing import List, Optional - from pydantic import BaseModel, Field, NonNegativeInt, model_validator from ert.config import ConfigWarning class ModelConfig(BaseModel, extra="forbid"): # type: ignore - realizations: List[NonNegativeInt] = Field( + realizations: list[NonNegativeInt] = Field( default_factory=lambda: [], description="""List of realizations to use in optimization ensemble. Typically, this is a list [0, 1, ..., n-1] of all realizations in the ensemble.""", ) - data_file: Optional[str] = Field( + data_file: str | None = Field( default=None, description="""Path to the eclipse data file used for optimization. The path can contain r{{geo_id}}. NOTE: Without a data file no well or group specific summary data will be exported.""", ) - realizations_weights: Optional[List[float]] = Field( + realizations_weights: list[float] | None = Field( default=None, description="""List of weights, one per realization. diff --git a/src/everest/config/objective_function_config.py b/src/everest/config/objective_function_config.py index cab5439c331..72f75c2551b 100644 --- a/src/everest/config/objective_function_config.py +++ b/src/everest/config/objective_function_config.py @@ -1,11 +1,9 @@ -from typing import Optional - from pydantic import BaseModel, Field, PositiveFloat, field_validator class ObjectiveFunctionConfig(BaseModel, extra="forbid"): # type: ignore name: str = Field() - alias: Optional[str] = Field( + alias: str | None = Field( default=None, description=""" alias can be set to the name of another objective function, directing everest @@ -17,7 +15,7 @@ class ObjectiveFunctionConfig(BaseModel, extra="forbid"): # type: ignore sure that the standard deviation is calculated over the values of that objective. """, ) - weight: Optional[PositiveFloat] = Field( + weight: PositiveFloat | None = Field( default=None, description=""" weight determines the importance of an objective function relative to the other @@ -28,7 +26,7 @@ class ObjectiveFunctionConfig(BaseModel, extra="forbid"): # type: ignore used in the optimization process. """, ) - normalization: Optional[float] = Field( + normalization: float | None = Field( default=None, description=""" normalization is a multiplication factor defined per objective function. @@ -40,7 +38,7 @@ class ObjectiveFunctionConfig(BaseModel, extra="forbid"): # type: ignore the weighted sum that Everest tries to optimize. """, ) - auto_normalize: Optional[bool] = Field( + auto_normalize: bool | None = Field( default=None, description=""" auto_normalize can be set to true to automatically @@ -49,7 +47,7 @@ class ObjectiveFunctionConfig(BaseModel, extra="forbid"): # type: ignore If normalization is also set, the automatic value is multiplied by its value. """, ) - type: Optional[str] = Field( + type: str | None = Field( default=None, description=""" type can be set to the name of a method that should be applied to calculate a diff --git a/src/everest/config/optimization_config.py b/src/everest/config/optimization_config.py index ed39835101a..b82df79e55b 100644 --- a/src/everest/config/optimization_config.py +++ b/src/everest/config/optimization_config.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any from pydantic import BaseModel, Field, model_validator @@ -7,13 +7,13 @@ class OptimizationConfig(BaseModel, extra="forbid"): # type: ignore - algorithm: Optional[str] = Field( + algorithm: str | None = Field( default="default", description="""Algorithm used by Everest. Defaults to optpp_q_newton, a quasi-Newton algorithm in Dakota's OPT PP library. """, ) - convergence_tolerance: Optional[float] = Field( + convergence_tolerance: float | None = Field( default=None, description="""Defines the threshold value on relative change in the objective function that indicates convergence. @@ -32,7 +32,7 @@ class OptimizationConfig(BaseModel, extra="forbid"): # type: ignore (From the Dakota Manual.)""", ) - backend: Optional[str] = Field( + backend: str | None = Field( default="dakota", description="""The optimization backend used. Defaults to "dakota". @@ -41,7 +41,7 @@ class OptimizationConfig(BaseModel, extra="forbid"): # type: ignore backend is optional, and will only be available if SciPy is installed on the system.""", ) - backend_options: Optional[Dict[str, Any]] = Field( + backend_options: dict[str, Any] | None = Field( default=None, description="""Dict of optional parameters for the optimizer backend. This dict of values is passed unchanged to the selected algorithm in the backend. @@ -50,7 +50,7 @@ class OptimizationConfig(BaseModel, extra="forbid"): # type: ignore list of strings rather than a dictionary. For setting Dakota backend options, see the 'option' keyword.""", ) - constraint_tolerance: Optional[float] = Field( + constraint_tolerance: float | None = Field( default=None, description="""Determines the maximum allowable value of infeasibility that any constraint in an optimization problem may possess and @@ -65,7 +65,7 @@ class OptimizationConfig(BaseModel, extra="forbid"): # type: ignore (From the Dakota Manual.)""", ) - cvar: Optional[CVaRConfig] = Field( + cvar: CVaRConfig | None = Field( default=None, description="""Directs the optimizer to use CVaR estimation. @@ -78,7 +78,7 @@ class OptimizationConfig(BaseModel, extra="forbid"): # type: ignore which are mutually exclusive. """, ) - max_batch_num: Optional[int] = Field( + max_batch_num: int | None = Field( default=None, gt=0, description="""Limits the number of batches of simulations @@ -86,7 +86,7 @@ class OptimizationConfig(BaseModel, extra="forbid"): # type: ignore When max_batch_num is specified and the current batch index is greater than max_batch_num an exception is raised.""", ) - max_function_evaluations: Optional[int] = Field( + max_function_evaluations: int | None = Field( default=None, gt=0, description="""Limits the maximum number of function evaluations. @@ -97,7 +97,7 @@ class OptimizationConfig(BaseModel, extra="forbid"): # type: ignore See max_iterations for a description. """, ) - max_iterations: Optional[int] = Field( + max_iterations: int | None = Field( default=None, gt=0, description="""Limits the maximum number of iterations. @@ -106,7 +106,7 @@ class OptimizationConfig(BaseModel, extra="forbid"): # type: ignore a complete accepted batch (i.e., a batch that provides an improvement in the objective function while satisfying all constraints).""", ) - min_pert_success: Optional[int] = Field( + min_pert_success: int | None = Field( default=None, gt=0, description="""specifies the minimum number of successfully completed @@ -124,7 +124,7 @@ class OptimizationConfig(BaseModel, extra="forbid"): # type: ignore `perturbation_num: 1` the maximum allowed value is the number of realizations specified by realizations instead.""", ) - min_realizations_success: Optional[int] = Field( + min_realizations_success: int | None = Field( default=None, ge=0, description="""Minimum number of realizations @@ -150,7 +150,7 @@ class OptimizationConfig(BaseModel, extra="forbid"): # type: ignore the value to be equal to one. """, ) - options: Optional[List[str]] = Field( + options: list[str] | None = Field( default=None, description="""specifies non-validated, optional passthrough parameters for the optimizer @@ -160,7 +160,7 @@ class OptimizationConfig(BaseModel, extra="forbid"): # type: ignore | - retry_if_fail | - classical_search 1""", ) - perturbation_num: Optional[int] = Field( + perturbation_num: int | None = Field( default=None, gt=0, description="""The number of perturbed control vectors per realization. @@ -168,7 +168,7 @@ class OptimizationConfig(BaseModel, extra="forbid"): # type: ignore The number of simulation runs used for estimating the gradient is equal to the the product of perturbation_num and model.realizations.""", ) - speculative: Optional[bool] = Field( + speculative: bool | None = Field( default=None, description="""specifies whether to enable speculative computation. @@ -182,7 +182,7 @@ class OptimizationConfig(BaseModel, extra="forbid"): # type: ignore parallelism in the gradient calculations can be exploited by Dakota (it will be ignored for vendor numerical gradients). (From the Dakota Manual.)""", ) - parallel: Optional[bool] = Field( + parallel: bool | None = Field( default=None, description="""whether to allow parallel function evaluation. diff --git a/src/everest/config/output_constraint_config.py b/src/everest/config/output_constraint_config.py index a0b15298e00..a6b0e74c683 100644 --- a/src/everest/config/output_constraint_config.py +++ b/src/everest/config/output_constraint_config.py @@ -1,11 +1,9 @@ -from typing import Optional - from pydantic import BaseModel, Field, model_validator class OutputConstraintConfig(BaseModel, extra="forbid"): # type: ignore name: str = Field(description="The unique name of the output constraint.") - target: Optional[float] = Field( + target: float | None = Field( default=None, description="""Defines the equality constraint @@ -16,14 +14,14 @@ class OutputConstraintConfig(BaseModel, extra="forbid"): # type: ignore """, ) - auto_scale: Optional[bool] = Field( + auto_scale: bool | None = Field( default=None, description="""If set to true, Everest will automatically determine the scaling factor from the constraint value in batch 0. If scale is also set, the automatic value is multiplied by its value.""", ) - lower_bound: Optional[float] = Field( + lower_bound: float | None = Field( default=None, description="""Defines the lower bound (greater than or equal) constraint @@ -34,7 +32,7 @@ class OutputConstraintConfig(BaseModel, extra="forbid"): # type: ignore the scale (scale). """, ) - upper_bound: Optional[float] = Field( + upper_bound: float | None = Field( default=None, description="""Defines the upper bound (less than or equal) constraint: @@ -43,7 +41,7 @@ class OutputConstraintConfig(BaseModel, extra="forbid"): # type: ignore where b is the upper bound, f is a function of the control vector x, and c is the scale (scale).""", ) - scale: Optional[float] = Field( + scale: float | None = Field( default=None, description="""Scaling of constraints (scale). diff --git a/src/everest/config/sampler_config.py b/src/everest/config/sampler_config.py index 5fbf454abb0..5f3c3a5a4e0 100644 --- a/src/everest/config/sampler_config.py +++ b/src/everest/config/sampler_config.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -15,7 +15,7 @@ class SamplerConfig(BaseModel): # type: ignore """, ) - options: Optional[Dict[str, Any]] = Field( + options: dict[str, Any] | None = Field( default=None, alias="backend_options", description=""" @@ -30,7 +30,7 @@ class SamplerConfig(BaseModel): # type: ignore description="""The sampling method or distribution used by the sampler backend. """, ) - shared: Optional[bool] = Field( + shared: bool | None = Field( default=None, description="""Whether to share perturbations between realizations. """, diff --git a/src/everest/config/server_config.py b/src/everest/config/server_config.py index 50fc7911634..bde3db4e13d 100644 --- a/src/everest/config/server_config.py +++ b/src/everest/config/server_config.py @@ -1,6 +1,6 @@ import json import os -from typing import Literal, Optional, Tuple +from typing import Literal from pydantic import BaseModel, ConfigDict, Field @@ -15,7 +15,7 @@ class ServerConfig(BaseModel, HasErtQueueOptions): # type: ignore - name: Optional[str] = Field( + name: str | None = Field( None, description="""Specifies which queue to use. @@ -27,23 +27,23 @@ class ServerConfig(BaseModel, HasErtQueueOptions): # type: ignore as RMS and Eclipse. """, ) # Corresponds to queue name - exclude_host: Optional[str] = Field( + exclude_host: str | None = Field( "", description="""Comma separated list of nodes that should be excluded from the slurm run""", ) - include_host: Optional[str] = Field( + include_host: str | None = Field( "", description="""Comma separated list of nodes that should be included in the slurm run""", ) - options: Optional[str] = Field( + options: str | None = Field( None, description="""Used to specify options to LSF. Examples to set memory requirement is: * rusage[mem=1000]""", ) - queue_system: Optional[Literal["lsf", "local", "slurm"]] = Field( + queue_system: Literal["lsf", "local", "slurm"] | None = Field( None, description="Defines which queue system the everest server runs on.", ) @@ -62,7 +62,7 @@ def get_server_url(output_dir: str) -> str: return f"https://{server_info['host']}:{server_info['port']}" @staticmethod - def get_server_context(output_dir: str) -> Tuple[str, str, Tuple[str, str]]: + def get_server_context(output_dir: str) -> tuple[str, str, tuple[str, str]]: """Returns a tuple with - url of the server - path to the .cert file @@ -80,7 +80,7 @@ def get_server_info(output_dir: str) -> dict: """Load server information from the hostfile""" host_file_path = ServerConfig.get_hostfile_path(output_dir) try: - with open(host_file_path, "r", encoding="utf-8") as f: + with open(host_file_path, encoding="utf-8") as f: json_string = f.read() data = json.loads(json_string) diff --git a/src/everest/config/simulator_config.py b/src/everest/config/simulator_config.py index b90238eb321..27da247de29 100644 --- a/src/everest/config/simulator_config.py +++ b/src/everest/config/simulator_config.py @@ -1,5 +1,5 @@ import warnings -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt, field_validator @@ -7,10 +7,8 @@ class SimulatorConfig(BaseModel, HasErtQueueOptions, extra="forbid"): # type: ignore - name: Optional[str] = Field( - default=None, description="Specifies which queue to use" - ) - cores: Optional[PositiveInt] = Field( + name: str | None = Field(default=None, description="Specifies which queue to use") + cores: PositiveInt | None = Field( default=None, description="""Defines the number of simultaneously running forward models. @@ -21,7 +19,7 @@ class SimulatorConfig(BaseModel, HasErtQueueOptions, extra="forbid"): # type: i This number is specified in Ert as MAX_RUNNING. """, ) - cores_per_node: Optional[PositiveInt] = Field( + cores_per_node: PositiveInt | None = Field( default=None, description="""defines the number of CPUs when running the forward models. This can for example be used in conjunction with the Eclipse @@ -30,47 +28,47 @@ class SimulatorConfig(BaseModel, HasErtQueueOptions, extra="forbid"): # type: i This number is specified in Ert as NUM_CPU.""", ) - delete_run_path: Optional[bool] = Field( + delete_run_path: bool | None = Field( default=None, description="Whether the batch folder for a successful simulation " "needs to be deleted.", ) - exclude_host: Optional[str] = Field( + exclude_host: str | None = Field( "", description="""Comma separated list of nodes that should be excluded from the slurm run.""", ) - include_host: Optional[str] = Field( + include_host: str | None = Field( "", description="""Comma separated list of nodes that should be included in the slurm run""", ) - max_memory: Optional[str] = Field( + max_memory: str | None = Field( default=None, description="Maximum memory usage for a slurm job.", ) - max_memory_cpu: Optional[str] = Field( + max_memory_cpu: str | None = Field( default=None, description="Maximum memory usage per cpu for a slurm job.", ) - max_runtime: Optional[NonNegativeInt] = Field( + max_runtime: NonNegativeInt | None = Field( default=None, description="""Maximum allowed running time of a forward model. When set, a job is only allowed to run for max_runtime seconds. A value of 0 means unlimited runtime. """, ) - options: Optional[str] = Field( + options: str | None = Field( default=None, description="""Used to specify options to LSF. Examples to set memory requirement is: * rusage[mem=1000]""", ) - queue_system: Optional[Literal["lsf", "local", "slurm", "torque"]] = Field( + queue_system: Literal["lsf", "local", "slurm", "torque"] | None = Field( default="local", description="Defines which queue system the everest server runs on.", ) - resubmit_limit: Optional[NonNegativeInt] = Field( + resubmit_limit: NonNegativeInt | None = Field( default=None, description=""" Defines how many times should the queue system retry a forward model. @@ -81,35 +79,35 @@ class SimulatorConfig(BaseModel, HasErtQueueOptions, extra="forbid"): # type: i resumbit_limit defines the number of times we will resubmit a failing forward model. If not specified, a default value of 1 will be used.""", ) - sbatch: Optional[str] = Field( + sbatch: str | None = Field( default=None, description="sbatch executable to be used by the slurm queue interface.", ) - scancel: Optional[str] = Field( + scancel: str | None = Field( default=None, description="scancel executable to be used by the slurm queue interface.", ) - scontrol: Optional[str] = Field( + scontrol: str | None = Field( default=None, description="scontrol executable to be used by the slurm queue interface.", ) - sacct: Optional[str] = Field( + sacct: str | None = Field( default=None, description="sacct executable to be used by the slurm queue interface.", ) - squeue: Optional[str] = Field( + squeue: str | None = Field( default=None, description="squeue executable to be used by the slurm queue interface.", ) - server: Optional[str] = Field( + server: str | None = Field( default=None, description="Name of LSF server to use. This option is deprecated and no longer required", ) - slurm_timeout: Optional[int] = Field( + slurm_timeout: int | None = Field( default=None, description="Timeout for cached status used by the slurm queue interface", ) - squeue_timeout: Optional[int] = Field( + squeue_timeout: int | None = Field( default=None, description="Timeout for cached status used by the slurm queue interface.", ) @@ -126,40 +124,40 @@ class SimulatorConfig(BaseModel, HasErtQueueOptions, extra="forbid"): # type: i the most common use of a standard optimization with a continuous optimizer.""", ) - qsub_cmd: Optional[str] = Field(default="qsub", description="The submit command") - qstat_cmd: Optional[str] = Field(default="qstat", description="The query command") - qdel_cmd: Optional[str] = Field(default="qdel", description="The kill command") - qstat_options: Optional[str] = Field( + qsub_cmd: str | None = Field(default="qsub", description="The submit command") + qstat_cmd: str | None = Field(default="qstat", description="The query command") + qdel_cmd: str | None = Field(default="qdel", description="The kill command") + qstat_options: str | None = Field( default="-x", description="Options to be supplied to the qstat command. This defaults to -x, which tells the qstat command to include exited processes.", ) - cluster_label: Optional[str] = Field( + cluster_label: str | None = Field( default=None, description="The name of the cluster you are running simulations in.", ) - memory_per_job: Optional[str] = Field( + memory_per_job: str | None = Field( default=None, description="""You can specify the amount of memory you will need for running your job. This will ensure that not too many jobs will run on a single shared memory node at once, possibly crashing the compute node if it runs out of memory. You can get an indication of the memory requirement by watching the course of a local run using the htop utility. Whether you should set the peak memory usage as your requirement or a lower figure depends on how simultaneously each job will run. The option to be supplied will be used as a string in the qsub argument. You must specify the unit, either gb or mb. """, ) - keep_qsub_output: Optional[int] = Field( + keep_qsub_output: int | None = Field( default=0, description="Set to 1 to keep error messages from qsub. Usually only to be used if somethign is seriously wrong with the queue environment/setup.", ) - submit_sleep: Optional[float] = Field( + submit_sleep: float | None = Field( default=0.5, description="To avoid stressing the TORQUE/PBS system you can instruct the driver to sleep for every submit request. The argument to the SUBMIT_SLEEP is the number of seconds to sleep for every submit, which can be a fraction like 0.5", ) - queue_query_timeout: Optional[int] = Field( + queue_query_timeout: int | None = Field( default=126, description=""" The driver allows the backend TORQUE/PBS system to be flaky, i.e. it may intermittently not respond and give error messages when submitting jobs or asking for job statuses. The timeout (in seconds) determines how long ERT will wait before it will give up. Applies to job submission (qsub) and job status queries (qstat). Default is 126 seconds. ERT will do exponential sleeps, starting at 2 seconds, and the provided timeout is a maximum. Let the timeout be sums of series like 2+4+8+16+32+64 in order to be explicit about the number of retries. Set to zero to disallow flakyness, setting it to 2 will allow for one re-attempt, and 6 will give two re-attempts. Example allowing six retries: """, ) - project_code: Optional[str] = Field( + project_code: str | None = Field( default=None, description="String identifier used to map hardware resource usage to a project or account. The project or account does not have to exist.", ) diff --git a/src/everest/config/validation_utils.py b/src/everest/config/validation_utils.py index 62b2d3132a3..05c150b70e6 100644 --- a/src/everest/config/validation_utils.py +++ b/src/everest/config/validation_utils.py @@ -2,9 +2,10 @@ import os import tempfile from collections import Counter +from collections.abc import Sequence from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, TypeVar, Union +from typing import TYPE_CHECKING, TypeVar from pydantic import BaseModel, ValidationError @@ -26,7 +27,7 @@ class InstallDataContext: - def __init__(self, install_data: List[InstallDataConfig], config_path: Path): + def __init__(self, install_data: list[InstallDataConfig], config_path: Path): self._install_data = install_data or [] self._config_dir = str(config_path.parent) self._cwd = os.getcwd() @@ -40,7 +41,7 @@ def __enter__(self): os.chdir(self._temp_dir.name) return self - def _set_symlink(self, source: str, target: str, realization: Optional[int]): + def _set_symlink(self, source: str, target: str, realization: int | None): if realization is not None: source = source.replace("", str(realization)) target = target.replace("", str(realization)) @@ -64,14 +65,14 @@ def __exit__(self, exc_type, exc_value, exc_tb): def control_variables_validation( name: str, - _min: Optional[float], - _max: Optional[float], - initial_guess: Union[float, List[float], None], -) -> List[str]: + min_: float | None, + max_: float | None, + initial_guess: float | list[float] | None, +) -> list[str]: error = [] - if _min is None: + if min_ is None: error.append(_VARIABLE_ERROR_MESSAGE.format(name=name, variable_type="min")) - if _max is None: + if max_ is None: error.append(_VARIABLE_ERROR_MESSAGE.format(name=name, variable_type="max")) if initial_guess is None: error.append( @@ -80,16 +81,16 @@ def control_variables_validation( if isinstance(initial_guess, float): initial_guess = [initial_guess] if ( - _min is not None - and _max is not None + min_ is not None + and max_ is not None and ( msg := ", ".join( - str(guess) for guess in initial_guess or [] if not _min <= guess <= _max + str(guess) for guess in initial_guess or [] if not min_ <= guess <= max_ ) ) ): error.append( - f"Variable {name} must respect {_min} <= initial_guess <= {_max}: {msg}" + f"Variable {name} must respect {min_} <= initial_guess <= {max_}: {msg}" ) return error @@ -141,7 +142,7 @@ def unique_items(items: Sequence[T]) -> Sequence[T]: return items -def valid_range(range_value: Tuple[float, float]): +def valid_range(range_value: tuple[float, float]): if range_value[0] >= range_value[1]: raise ValueError("scaled_range must be a valid range [a, b], where a < b.") return range_value @@ -172,7 +173,7 @@ def check_writable_filepath(path: str): raise ValueError(f"User does not have write access to {path}") -def check_for_duplicate_names(names: List[str], item_name: str, key: str = "item"): +def check_for_duplicate_names(names: list[str], item_name: str, key: str = "item"): if len(set(names)) != len(names): histogram = {k: names.count(k) for k in set(names) if names.count(k) > 1} occurrences_str = ", ".join( @@ -191,14 +192,14 @@ def as_abs_path(path: str, config_dir: str) -> str: return os.path.realpath(os.path.join(config_dir, path)) -def expand_geo_id_paths(path_source: str, realizations: List[int]): +def expand_geo_id_paths(path_source: str, realizations: list[int]): if "" in path_source: return [path_source.replace("", str(r)) for r in realizations] return [path_source] def check_path_exists( - path_source: str, config_path: Optional[Path], realizations: List[int] + path_source: str, config_path: Path | None, realizations: list[int] ): """Check if the given path exists. If the given path contains or GEO_ID they will be expanded and all instances of expanded paths need to exist. @@ -265,7 +266,7 @@ def format_errors(error: ValidationError) -> str: def validate_forward_model_configs( - forward_model: Optional[List[str]], install_jobs: Optional[List[InstallJobConfig]] + forward_model: list[str] | None, install_jobs: list[InstallJobConfig] | None ): if not forward_model: return diff --git a/src/everest/config/well_config.py b/src/everest/config/well_config.py index 87872e1c6b8..32fbfa8c20e 100644 --- a/src/everest/config/well_config.py +++ b/src/everest/config/well_config.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -8,7 +7,7 @@ class WellConfig(BaseModel): name: str = Field(description="The unique name of the well") - drill_date: Optional[str] = Field( + drill_date: str | None = Field( None, description="""Ideal date to drill a well. @@ -16,7 +15,7 @@ class WellConfig(BaseModel): consider this as the earliest possible drill date. """, ) - drill_time: Optional[float] = Field( + drill_time: float | None = Field( None, description="""specifies the time it takes to drill the well under consideration.""", diff --git a/src/everest/config/workflow_config.py b/src/everest/config/workflow_config.py index bd5dadf3048..946005d02a7 100644 --- a/src/everest/config/workflow_config.py +++ b/src/everest/config/workflow_config.py @@ -1,14 +1,12 @@ -from typing import List, Optional - from pydantic import BaseModel, ConfigDict, Field class WorkflowConfig(BaseModel): # type: ignore - pre_simulation: Optional[List[str]] = Field( + pre_simulation: list[str] | None = Field( default=None, description="List of workflow jobs triggered pre-simulation", ) - post_simulation: Optional[List[str]] = Field( + post_simulation: list[str] | None = Field( default=None, description="List of workflow jobs triggered post-simulation", ) diff --git a/src/everest/config_file_loader.py b/src/everest/config_file_loader.py index 4aa49fea9d6..15f1d29ed6d 100644 --- a/src/everest/config_file_loader.py +++ b/src/everest/config_file_loader.py @@ -2,7 +2,7 @@ import logging import os -from typing import Any, Dict, List, Optional +from typing import Any import jinja2 from ruamel.yaml import YAML, YAMLError @@ -28,9 +28,9 @@ } -def load_yaml(file_name: str) -> Optional[Dict[str, Any]]: - with open(file_name, "r", encoding="utf-8") as input_file: - input_data: List[str] = input_file.readlines() +def load_yaml(file_name: str) -> dict[str, Any] | None: + with open(file_name, encoding="utf-8") as input_file: + input_data: list[str] = input_file.readlines() try: yaml = YAML() yaml.preserve_quotes = True @@ -85,7 +85,7 @@ def _os(): """ - class Os(object): + class Os: pass x = Os() @@ -114,7 +114,7 @@ def _render_definitions(definitions, jinja_env): ) -def yaml_file_to_substituted_config_dict(config_path: str) -> Dict[str, Any]: +def yaml_file_to_substituted_config_dict(config_path: str) -> dict[str, Any]: configuration = load_yaml(config_path) definitions = _get_definitions( @@ -122,7 +122,7 @@ def yaml_file_to_substituted_config_dict(config_path: str) -> Dict[str, Any]: configpath=os.path.dirname(os.path.abspath(config_path)), ) definitions["os"] = _os() # update definitions with os namespace - with open(config_path, "r", encoding="utf-8") as f: + with open(config_path, encoding="utf-8") as f: txt = "".join(f.readlines()) jenv = jinja2.Environment( block_start_string=BLOCK_START_STRING, @@ -137,7 +137,7 @@ def yaml_file_to_substituted_config_dict(config_path: str) -> Dict[str, Any]: # Load the config with definitions again as yaml yaml = YAML(typ="safe", pure=True).load(config) - if not isinstance(yaml, Dict): + if not isinstance(yaml, dict): yaml = {} # Inject config path diff --git a/src/everest/detached/__init__.py b/src/everest/detached/__init__.py index d5072269c89..503ea932928 100644 --- a/src/everest/detached/__init__.py +++ b/src/everest/detached/__init__.py @@ -6,9 +6,10 @@ import re import time import traceback +from collections.abc import Mapping from enum import Enum from pathlib import Path -from typing import Literal, Mapping, Optional, Tuple +from typing import Literal import requests from seba_sqlite.exceptions import ObjectNotFoundError @@ -72,7 +73,7 @@ async def start_server(config: EverestConfig, debug: bool = False) -> Driver: return driver -def stop_server(server_context: Tuple[str, str, Tuple[str, str]], retries: int = 5): +def stop_server(server_context: tuple[str, str, tuple[str, str]], retries: int = 5): """ Stop server if found and it is running. """ @@ -95,7 +96,7 @@ def stop_server(server_context: Tuple[str, str, Tuple[str, str]], retries: int = def extract_errors_from_file(path: str): - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: content = f.read() return re.findall(r"(Error \w+.*)", content) @@ -163,7 +164,7 @@ def get_opt_status(output_folder): } -def wait_for_server_to_stop(server_context: Tuple[str, str, Tuple[str, str]], timeout): +def wait_for_server_to_stop(server_context: tuple[str, str, tuple[str, str]], timeout): """ Checks everest server has stoped _HTTP_REQUEST_RETRY times. Waits progressively longer between each check. @@ -183,7 +184,7 @@ def wait_for_server_to_stop(server_context: Tuple[str, str, Tuple[str, str]], ti raise Exception("Failed to stop server within configured timeout.") -def server_is_running(url: str, cert: str, auth: Tuple[str, str]): +def server_is_running(url: str, cert: str, auth: tuple[str, str]): try: response = requests.get( url, @@ -200,7 +201,7 @@ def server_is_running(url: str, cert: str, auth: Tuple[str, str]): def start_monitor( - server_context: Tuple[str, str, Tuple[str, str]], callback, polling_interval=5 + server_context: tuple[str, str, tuple[str, str]], callback, polling_interval=5 ): """ Checks status on Everest server and calls callback when status changes @@ -259,8 +260,8 @@ def start_monitor( def _find_res_queue_system( - simulator: Optional[SimulatorConfig], - server: Optional[ServerConfig], + simulator: SimulatorConfig | None, + server: ServerConfig | None, ): queue_system_simulator: Literal["lsf", "local", "slurm", "torque"] = "local" if simulator is not None and simulator.queue_system is not None: @@ -282,8 +283,8 @@ def _find_res_queue_system( def get_server_queue_options( - simulator: Optional[SimulatorConfig], - server: Optional[ServerConfig], + simulator: SimulatorConfig | None, + server: ServerConfig | None, ) -> QueueOptions: script = ErtPluginManager().activate_script() or activate_script() queue_system = _find_res_queue_system(simulator, server) @@ -349,7 +350,7 @@ def decode(obj): def update_everserver_status( - everserver_status_path: str, status: ServerStatus, message: Optional[str] = None + everserver_status_path: str, status: ServerStatus, message: str | None = None ): """Update the everest server status with new status information""" new_status = {"status": status, "message": message} @@ -382,7 +383,7 @@ def everserver_status(everserver_status_path: str): } """ if os.path.exists(everserver_status_path): - with open(everserver_status_path, "r", encoding="utf-8") as f: + with open(everserver_status_path, encoding="utf-8") as f: return json.load(f, object_hook=ServerStatusEncoder.decode) else: return {"status": ServerStatus.never_run, "message": None} diff --git a/src/everest/detached/jobs/everserver.py b/src/everest/detached/jobs/everserver.py index c11f57d33bd..40bf0485e77 100755 --- a/src/everest/detached/jobs/everserver.py +++ b/src/everest/detached/jobs/everserver.py @@ -172,7 +172,7 @@ def _find_open_port(host, lower, upper) -> int: sock.bind((host, port)) sock.close() return port - except socket.error: + except OSError: logging.getLogger("everserver").info( "Port {} for host {} is taken".format(port, host) ) @@ -403,12 +403,10 @@ def _failed_realizations_messages(shared_data): # Find the set of jobs that failed. To keep the order in which they # are found in the queue, use a dict as sets are not ordered. failed_jobs = dict.fromkeys( - ( - job["name"] - for queue in shared_data[SIM_PROGRESS_ENDPOINT]["progress"] - for job in queue - if job["status"] == JOB_FAILURE - ) + job["name"] + for queue in shared_data[SIM_PROGRESS_ENDPOINT]["progress"] + for job in queue + if job["status"] == JOB_FAILURE ).keys() messages.append( "{} job failures caused by: {}".format(failed, ", ".join(failed_jobs)) diff --git a/src/everest/docs/generate_docs_from_config_spec.py b/src/everest/docs/generate_docs_from_config_spec.py index 7c28287bfec..a17b2f9d7c1 100644 --- a/src/everest/docs/generate_docs_from_config_spec.py +++ b/src/everest/docs/generate_docs_from_config_spec.py @@ -2,7 +2,6 @@ import inspect import re import sys -from typing import Dict, List, Optional from pydantic.fields import FieldInfo @@ -21,7 +20,7 @@ class ParsedField: description: str type: str is_required: bool - subfields: Optional[List["ParsedField"]] + subfields: list["ParsedField"] | None def doc_title(self) -> str: return f"{self.name} ({'required' if self.is_required else 'optional'})" @@ -46,7 +45,7 @@ def clean_type(self) -> str: return self.type -def parse_field_info(field_infos: Dict[str, FieldInfo]): +def parse_field_info(field_infos: dict[str, FieldInfo]): """ Extracts relevant info from a list of pydantic model fields into a convenient format of ParsedField items, to be used for further generation of docs @@ -140,9 +139,9 @@ def add_newline(self): def _generate_rst( - parsed_fields: List[ParsedField], + parsed_fields: list[ParsedField], level: int = 0, - builder: Optional[DocBuilder] = None, + builder: DocBuilder | None = None, extended=False, ): if not builder: diff --git a/src/everest/export.py b/src/everest/export.py index 1c64a2fa75e..de0a2de79ae 100644 --- a/src/everest/export.py +++ b/src/everest/export.py @@ -1,7 +1,7 @@ import os import re from enum import StrEnum -from typing import Any, Dict, List, Optional, Set +from typing import Any import pandas as pd from pandas import DataFrame @@ -45,7 +45,7 @@ def get_all(cls): ] -def filter_data(data: DataFrame, keyword_filters: Set[str]): +def filter_data(data: DataFrame, keyword_filters: set[str]): filtered_columns = [] for col in data.columns: @@ -57,14 +57,14 @@ def filter_data(data: DataFrame, keyword_filters: Set[str]): return data[filtered_columns] -def available_batches(optimization_output_dir: str) -> Set[int]: +def available_batches(optimization_output_dir: str) -> set[int]: snapshot = SebaSnapshot(optimization_output_dir).get_snapshot( filter_out_gradient=False, batches=None ) return {data.batch for data in snapshot.simulation_data} -def export_metadata(config: Optional[ExportConfig], optimization_output_dir: str): +def export_metadata(config: ExportConfig | None, optimization_output_dir: str): discard_gradient = True discard_rejected = True batches = None @@ -100,7 +100,7 @@ def export_metadata(config: Optional[ExportConfig], optimization_output_dir: str ): continue - md_row: Dict[str, Any] = { + md_row: dict[str, Any] = { MetaDataColumnNames.BATCH: data.batch, MetaDataColumnNames.SIM_AVERAGED_OBJECTIVE: data.sim_avg_obj, MetaDataColumnNames.IS_GRADIENT: data.is_gradient, @@ -147,12 +147,12 @@ def get_internalized_keys( config: ExportConfig, storage_path: str, optimization_output_path: str, - batch_ids: Optional[Set[int]] = None, + batch_ids: set[int] | None = None, ): if batch_ids is None: metadata = export_metadata(config, optimization_output_path) batch_ids = {data[MetaDataColumnNames.BATCH] for data in metadata} - internal_keys: Set = set() + internal_keys: set = set() with open_storage(storage_path, "r") as storage: for batch_id in batch_ids: case_name = f"batch_{batch_id}" @@ -177,24 +177,24 @@ def check_for_errors( config: ExportConfig, optimization_output_path: str, storage_path: str, - data_file_path: Optional[str], + data_file_path: str | None, ): """ Checks for possible errors when attempting to export current optimization case. """ export_ecl = True - export_errors: List[str] = [] + export_errors: list[str] = [] if config.batches: - _available_batches = available_batches(optimization_output_path) - for batch in set(config.batches).difference(_available_batches): + available_batches_ = available_batches(optimization_output_path) + for batch in set(config.batches).difference(available_batches_): export_errors.append( "Batch {} not found in optimization " "results. Skipping for current export." "".format(batch) ) - config.batches = list(set(config.batches).intersection(_available_batches)) + config.batches = list(set(config.batches).intersection(available_batches_)) if config.batches == []: export_errors.append( @@ -242,9 +242,9 @@ def check_for_errors( def export_data( - export_config: Optional[ExportConfig], + export_config: ExportConfig | None, output_dir: str, - data_file: Optional[str], + data_file: str | None, export_ecl=True, progress_callback=lambda _: None, ): @@ -291,7 +291,7 @@ def export_data( def load_simulation_data( - output_path: str, metadata: List[dict], progress_callback=lambda _: None + output_path: str, metadata: list[dict], progress_callback=lambda _: None ): """Export simulations to a pandas DataFrame @output_path optimization output folder path. diff --git a/src/everest/jobs/__init__.py b/src/everest/jobs/__init__.py index 4097323e3f4..63fb608b42d 100644 --- a/src/everest/jobs/__init__.py +++ b/src/everest/jobs/__init__.py @@ -34,19 +34,19 @@ def fetch_script_path(script_name: str) -> str: / rel_script_path ) - _scripts = {} + scripts = {} for script_name in script_names: - _scripts[script_name] = fetch_script_path(script_name) - globals()[script_name] = _scripts[script_name] + scripts[script_name] = fetch_script_path(script_name) + globals()[script_name] = scripts[script_name] - globals()["_scripts"] = _scripts + globals()["_scripts"] = scripts def fetch_script(script_name): if script_name in _scripts: # noqa F821 return _scripts[script_name] # noqa F821 else: - raise KeyError("Unknown script: %s" % script_name) + raise KeyError("Unknown script: {}".format(script_name)) _inject_scripts() diff --git a/src/everest/jobs/io/__init__.py b/src/everest/jobs/io/__init__.py index cc601cf9032..c529ce33ccf 100644 --- a/src/everest/jobs/io/__init__.py +++ b/src/everest/jobs/io/__init__.py @@ -23,7 +23,7 @@ def load_data(filename): json_err = err err_msg = "%s is neither yaml (err_msg=%s) nor json (err_msg=%s)" - raise IOError(err_msg % (filename, str(yaml_err), str(json_err))) + raise OSError(err_msg % (filename, str(yaml_err), str(json_err))) def _create_folders(filename): diff --git a/src/everest/jobs/templating/render.py b/src/everest/jobs/templating/render.py index 2b08f765d46..db2aba6be19 100644 --- a/src/everest/jobs/templating/render.py +++ b/src/everest/jobs/templating/render.py @@ -28,10 +28,10 @@ def _load_input(input_files): def _assert_input(input_files, template_file, output_file): for input_file in input_files: if not os.path.isfile(input_file): - raise ValueError("Input file: %s, does not exist.." % input_file) + raise ValueError("Input file: {}, does not exist..".format(input_file)) if not os.path.isfile(template_file): - raise ValueError("Template file: %s, does not exist.." % template_file) + raise ValueError("Template file: {}, does not exist..".format(template_file)) if not isinstance(output_file, str): raise TypeError("Expected output path to be a string") diff --git a/src/everest/optimizer/everest2ropt.py b/src/everest/optimizer/everest2ropt.py index 8df0a90ac21..06015f8cfb5 100644 --- a/src/everest/optimizer/everest2ropt.py +++ b/src/everest/optimizer/everest2ropt.py @@ -1,20 +1,15 @@ import os from collections import defaultdict +from collections.abc import Sequence from dataclasses import asdict, dataclass from typing import ( Any, - DefaultDict, - Dict, - List, - Optional, - Sequence, - Tuple, - Union, + Final, + TypeAlias, ) from ropt.config.enopt import EnOptConfig from ropt.enums import ConstraintType, PerturbationType, VariableType -from typing_extensions import Final, TypeAlias from everest.config import ( ControlConfig, @@ -31,10 +26,10 @@ ControlVariableGuessListConfig, ) -VariableName: TypeAlias = Tuple[str, str, int] -ControlName: TypeAlias = Union[Tuple[str, str], VariableName, List[VariableName]] -StrListDict: TypeAlias = DefaultDict[str, list] -IGNORE_KEYS: Final[Tuple[str, ...]] = ( +VariableName: TypeAlias = tuple[str, str, int] +ControlName: TypeAlias = tuple[str, str] | VariableName | list[VariableName] +StrListDict: TypeAlias = defaultdict[str, list] +IGNORE_KEYS: Final[tuple[str, ...]] = ( "enabled", "scaled_range", "auto_scale", @@ -45,10 +40,10 @@ def _collect_sampler( - sampler: Optional[SamplerConfig], - storage: Dict[str, Any], - control_name: Union[List[ControlName], ControlName, None] = None, -) -> Optional[Dict[str, Any]]: + sampler: SamplerConfig | None, + storage: dict[str, Any], + control_name: list[ControlName] | ControlName | None = None, +) -> dict[str, Any] | None: if sampler is None: return None map = sampler.model_dump(exclude_none=True, exclude={"backend", "method"}) @@ -64,37 +59,37 @@ def _collect_sampler( def _scale_translations( is_scale: bool, - _min: float, - _max: float, + min_: float, + max_: float, lower_bound: float, upper_bound: float, perturbation_type: PerturbationType, -) -> Tuple[float, float, int]: +) -> tuple[float, float, int]: if not is_scale: return 1.0, 0.0, perturbation_type.value - scale = (_max - _min) / (upper_bound - lower_bound) - return scale, _min - lower_bound * scale, PerturbationType.SCALED.value + scale = (max_ - min_) / (upper_bound - lower_bound) + return scale, min_ - lower_bound * scale, PerturbationType.SCALED.value @dataclass class Control: - name: Tuple[str, str] + name: tuple[str, str] enabled: bool lower_bounds: float upper_bounds: float - perturbation_magnitudes: Optional[float] - initial_values: List[float] + perturbation_magnitudes: float | None + initial_values: list[float] types: VariableType - scaled_range: Tuple[float, float] + scaled_range: tuple[float, float] auto_scale: bool - index: Optional[int] + index: int | None scales: float offsets: float perturbation_types: int def _resolve_everest_control( - variable: Union[ControlVariableConfig, ControlVariableGuessListConfig], + variable: ControlVariableConfig | ControlVariableGuessListConfig, group: ControlConfig, ) -> Control: scaled_range = variable.scaled_range or group.scaled_range or (0, 1.0) @@ -137,7 +132,7 @@ def _variable_initial_guess_list_injection( *, variables: StrListDict, gradients: StrListDict, -) -> List[VariableName]: +) -> list[VariableName]: guesses = len(control.initial_values) ropt_names = [(*control.name, index + 1) for index in range(guesses)] variables["names"].extend(ropt_names) @@ -241,13 +236,13 @@ def _parse_controls(controls: Sequence[ControlConfig], ropt_config): ropt_config["gradient"]["samplers"] = sampler_indices -def _parse_objectives(objective_functions: List[ObjectiveFunctionConfig], ropt_config): - names: List[str] = [] - scales: List[float] = [] - auto_scale: List[bool] = [] - weights: List[float] = [] - transform_indices: List[int] = [] - transforms: List = [] +def _parse_objectives(objective_functions: list[ObjectiveFunctionConfig], ropt_config): + names: list[str] = [] + scales: list[float] = [] + auto_scale: list[bool] = [] + weights: list[float] = [] + transform_indices: list[int] = [] + transforms: list = [] for objective in objective_functions: assert isinstance(objective.name, str) @@ -291,7 +286,7 @@ def _parse_objectives(objective_functions: List[ObjectiveFunctionConfig], ropt_c def _parse_input_constraints( - input_constraints: Optional[List[InputConstraintConfig]], + input_constraints: list[InputConstraintConfig] | None, ropt_config, formatted_names, ): @@ -331,19 +326,19 @@ def _add_input_constraint(rhs_value, coefficients, constraint_type): def _parse_output_constraints( - output_constraints: Optional[List[OutputConstraintConfig]], ropt_config + output_constraints: list[OutputConstraintConfig] | None, ropt_config ): if not output_constraints: return - names: List[str] = [] - rhs_values: List[float] = [] - scales: List[float] = [] - auto_scale: List[bool] = [] - types: List[ConstraintType] = [] + names: list[str] = [] + rhs_values: list[float] = [] + scales: list[float] = [] + auto_scale: list[bool] = [] + types: list[ConstraintType] = [] def _add_output_constraint( - rhs_value: Optional[float], constraint_type: ConstraintType, suffix=None + rhs_value: float | None, constraint_type: ConstraintType, suffix=None ): if rhs_value is not None: name = constr.name @@ -386,7 +381,7 @@ def _add_output_constraint( def _parse_optimization( - ever_opt: Optional[OptimizationConfig], + ever_opt: OptimizationConfig | None, has_output_constraints: bool, ropt_config, ): @@ -451,7 +446,7 @@ def _parse_optimization( if cvar_opts := ever_opt.cvar or None: # set up the configuration of the realization filter that implements cvar: if (percentile := cvar_opts.percentile) is not None: - cvar_config: Dict[str, Any] = { + cvar_config: dict[str, Any] = { "method": "cvar-objective", "options": {"percentile": percentile}, } @@ -485,8 +480,8 @@ def _parse_optimization( def _parse_model( - ever_model: Optional[ModelConfig], - ever_opt: Optional[OptimizationConfig], + ever_model: ModelConfig | None, + ever_opt: OptimizationConfig | None, ropt_config, ): if not ever_model: @@ -507,7 +502,7 @@ def _parse_model( def _parse_environment( - optimization_output_dir: str, random_seed: Optional[int], ropt_config + optimization_output_dir: str, random_seed: int | None, ropt_config ): ropt_config["optimizer"]["output_dir"] = os.path.abspath(optimization_output_dir) if random_seed is not None: @@ -521,7 +516,7 @@ def everest2ropt(ever_config: EverestConfig) -> EnOptConfig: the values are actually extracted, all the others are set to some more or less reasonable default """ - ropt_config: Dict[str, Any] = {} + ropt_config: dict[str, Any] = {} _parse_controls(ever_config.controls, ropt_config) diff --git a/src/everest/plugins/everest_plugin_manager.py b/src/everest/plugins/everest_plugin_manager.py index cc0cb52ca40..3a29ebd23a1 100644 --- a/src/everest/plugins/everest_plugin_manager.py +++ b/src/everest/plugins/everest_plugin_manager.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any import pluggy @@ -8,7 +8,7 @@ class EverestPluginManager(pluggy.PluginManager): def __init__(self, plugins=None) -> None: - super(EverestPluginManager, self).__init__(EVEREST) + super().__init__(EVEREST) self.add_hookspecs(hook_specs) if plugins is None: self.register(hook_impl) @@ -17,6 +17,6 @@ def __init__(self, plugins=None) -> None: for plugin in plugins: self.register(plugin) - def get_documentation(self) -> Dict[str, Any]: + def get_documentation(self) -> dict[str, Any]: docs = self.hook.get_forward_model_documentations() return {k: v for d in docs for k, v in d.items()} if docs else {} diff --git a/src/everest/plugins/hook_specs.py b/src/everest/plugins/hook_specs.py index 3c2281e5f60..c0d7b2e65f9 100644 --- a/src/everest/plugins/hook_specs.py +++ b/src/everest/plugins/hook_specs.py @@ -1,4 +1,5 @@ -from typing import List, Sequence, Type, TypeVar +from collections.abc import Sequence +from typing import TypeVar from everest.plugins import hookspec @@ -76,7 +77,7 @@ def get_forward_models_schemas(): @hookspec -def parse_forward_model_schema(path: str, schema: Type[T]): +def parse_forward_model_schema(path: str, schema: type[T]): """ Given a path and schema type, this hook will parse the file. """ @@ -106,7 +107,7 @@ def get_forward_model_documentations(): @hookspec(firstresult=True) -def custom_forward_model_outputs(forward_model_steps: List[str]): +def custom_forward_model_outputs(forward_model_steps: list[str]): """ Check if the given forward model steps will output to a file maching the defined everest objective diff --git a/src/everest/queue_driver/queue_driver.py b/src/everest/queue_driver/queue_driver.py index 33dcb1a1f09..0a90a248bad 100644 --- a/src/everest/queue_driver/queue_driver.py +++ b/src/everest/queue_driver/queue_driver.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Tuple +from typing import Any from ert.config import QueueSystem from everest.config import EverestConfig @@ -43,8 +43,8 @@ def _extract_ert_queue_options_from_simulator_config( - simulator: Optional[SimulatorConfig], queue_system -) -> List[Tuple[str, str, Any]]: + simulator: SimulatorConfig | None, queue_system +) -> list[tuple[str, str, Any]]: if simulator is None: simulator = SimulatorConfig() diff --git a/src/everest/simulator/everest_to_ert.py b/src/everest/simulator/everest_to_ert.py index fe717684889..17f280a2a24 100644 --- a/src/everest/simulator/everest_to_ert.py +++ b/src/everest/simulator/everest_to_ert.py @@ -3,7 +3,6 @@ import json import logging import os -from typing import DefaultDict, Dict, List, Union import everest from ert.config import ErtConfig, ExtParamConfig @@ -193,7 +192,7 @@ def _fetch_everest_jobs(ever_config: EverestConfig): assert ever_config.output_dir is not None job_storage = os.path.join(ever_config.output_dir, ".jobs") logging.getLogger(EVEREST).debug( - "Creating job description files in %s" % job_storage + "Creating job description files in {}".format(job_storage) ) if not os.path.isdir(job_storage): @@ -206,14 +205,14 @@ def _fetch_everest_jobs(ever_config: EverestConfig): script = everest.jobs.fetch_script(default_job) job_spec_file = os.path.join(job_storage, "_" + default_job) with open(job_spec_file, "w", encoding="utf-8") as f: - f.write("EXECUTABLE %s" % script) + f.write("EXECUTABLE {}".format(script)) ever_jobs.append(Job(name=default_job, source=job_spec_file)) return ever_jobs -def _job_to_dict(job: Union[dict, InstallJobConfig]) -> Union[dict, InstallJobConfig]: +def _job_to_dict(job: dict | InstallJobConfig) -> dict | InstallJobConfig: if type(job) is InstallJobConfig: return job.model_dump(exclude_none=True) return job @@ -288,7 +287,7 @@ def _internal_data_files(ever_config: EverestConfig): assert ever_config.output_dir is not None data_storage = os.path.join(ever_config.output_dir, ".internal_data") data_storage = os.path.realpath(data_storage) - logging.getLogger(EVEREST).debug("Storing internal data in %s" % data_storage) + logging.getLogger(EVEREST).debug("Storing internal data in {}".format(data_storage)) if not os.path.isdir(data_storage): os.makedirs(data_storage) @@ -493,16 +492,14 @@ def everest_to_ert_config(ever_config: EverestConfig) -> ErtConfig: ens_config = ert_config.ensemble_config def _get_variables( - variables: Union[ - List[ControlVariableConfig], List[ControlVariableGuessListConfig] - ], - ) -> Union[List[str], Dict[str, List[str]]]: + variables: list[ControlVariableConfig] | list[ControlVariableGuessListConfig], + ) -> list[str] | dict[str, list[str]]: if ( isinstance(variables[0], ControlVariableConfig) and getattr(variables[0], "index", None) is None ): return [var.name for var in variables] - result: DefaultDict[str, list] = collections.defaultdict(list) + result: collections.defaultdict[str, list] = collections.defaultdict(list) for variable in variables: if isinstance(variable, ControlVariableGuessListConfig): result[variable.name].extend( diff --git a/src/everest/simulator/simulator_cache.py b/src/everest/simulator/simulator_cache.py index bdaf9ef2a69..db4e3a7ae3c 100644 --- a/src/everest/simulator/simulator_cache.py +++ b/src/everest/simulator/simulator_cache.py @@ -1,6 +1,5 @@ from collections import defaultdict from itertools import count -from typing import DefaultDict, Dict, List, Optional, Tuple import numpy as np from numpy._typing import NDArray @@ -16,12 +15,12 @@ class SimulatorCache: def __init__(self) -> None: # Stores the realization/controls key, together with an ID. - self._keys: DefaultDict[int, List[Tuple[NDArray[np.float64], int]]] = ( + self._keys: defaultdict[int, list[tuple[NDArray[np.float64], int]]] = ( defaultdict(list) ) # Store objectives and constraints by ID: - self._objectives: Dict[int, NDArray[np.float64]] = {} - self._constraints: Dict[int, NDArray[np.float64]] = {} + self._objectives: dict[int, NDArray[np.float64]] = {} + self._constraints: dict[int, NDArray[np.float64]] = {} # Generate unique ID's: self._counter = count() @@ -32,7 +31,7 @@ def add_simulation_results( real_id: int, control_values: NDArray[np.float64], objectives: NDArray[np.float64], - constraints: Optional[NDArray[np.float64]], + constraints: NDArray[np.float64] | None, ): cache_id = next(self._counter) self._keys[real_id].append((control_values[sim_idx, :].copy(), cache_id)) @@ -40,9 +39,7 @@ def add_simulation_results( if constraints is not None: self._constraints[cache_id] = constraints[sim_idx, ...].copy() - def find_key( - self, real_id: int, control_vector: NDArray[np.float64] - ) -> Optional[int]: + def find_key(self, real_id: int, control_vector: NDArray[np.float64]) -> int | None: # Brute-force search, premature optimization is the root of all evil: for cached_vector, cache_id in self._keys.get(real_id, []): if np.allclose( diff --git a/src/everest/util/__init__.py b/src/everest/util/__init__.py index 4319c202e9f..dfe660ba154 100644 --- a/src/everest/util/__init__.py +++ b/src/everest/util/__init__.py @@ -68,13 +68,13 @@ def _roll_dir(old_name): old_name = os.path.realpath(old_name) new_name = old_name + datetime.datetime.utcnow().strftime("__%Y-%m-%d_%H.%M.%S.%f") os.rename(old_name, new_name) - logging.getLogger(EVEREST).info("renamed %s to %s" % (old_name, new_name)) + logging.getLogger(EVEREST).info("renamed {} to {}".format(old_name, new_name)) def load_deck(fname): """Take a .DATA file and return an opm.io.Deck.""" if not os.path.exists(fname): - raise IOError('No such data file "%s".' % fname) + raise OSError('No such data file "{}".'.format(fname)) if not has_opm(): raise RuntimeError("Cannot load ECL files, opm could not be imported") diff --git a/src/everest/util/async_run.py b/src/everest/util/async_run.py index e62550995b9..3a308497469 100644 --- a/src/everest/util/async_run.py +++ b/src/everest/util/async_run.py @@ -18,7 +18,7 @@ def async_run(function, on_finished=None, on_error=None): class _AsyncRunner(Thread): def __init__(self, function=None, on_finished=None, on_error=None): - super(_AsyncRunner, self).__init__() + super().__init__() self._function = function self._on_finished = on_finished self._on_error = on_error diff --git a/src/everest/util/forward_models.py b/src/everest/util/forward_models.py index 3ae8d8af0c0..a4f4a0af266 100644 --- a/src/everest/util/forward_models.py +++ b/src/everest/util/forward_models.py @@ -1,4 +1,4 @@ -from typing import List, Set, Type, TypeVar +from typing import TypeVar from pydantic import BaseModel, ValidationError @@ -16,12 +16,12 @@ def collect_forward_model_schemas(): return {} -def lint_forward_model_job(job: str, args) -> List[str]: +def lint_forward_model_job(job: str, args) -> list[str]: return pm.hook.lint_forward_model(job=job, args=args) def check_forward_model_objective( - forward_model_steps: List[str], objectives: Set[str] + forward_model_steps: list[str], objectives: set[str] ) -> None: if not objectives or not forward_model_steps: return @@ -39,7 +39,7 @@ def check_forward_model_objective( ) -def parse_forward_model_file(path: str, schema: Type[T], message: str) -> T: +def parse_forward_model_file(path: str, schema: type[T], message: str) -> T: try: res = pm.hook.parse_forward_model_schema(path=path, schema=schema) if res: diff --git a/test-data/ert/batch_sim/workflows/jobs/realization_number.py b/test-data/ert/batch_sim/workflows/jobs/realization_number.py index 2f3cf6240bd..687bbd15f1b 100755 --- a/test-data/ert/batch_sim/workflows/jobs/realization_number.py +++ b/test-data/ert/batch_sim/workflows/jobs/realization_number.py @@ -8,7 +8,7 @@ def add_file_to_realization_runpaths(runpath_file): - with open(runpath_file, "r", encoding="utf-8") as fh: + with open(runpath_file, encoding="utf-8") as fh: runpath_file_lines = fh.readlines() for line in runpath_file_lines: diff --git a/test-data/ert/heat_equation/generate_files.py b/test-data/ert/heat_equation/generate_files.py index 769ceefe82e..2a4866d8ff4 100644 --- a/test-data/ert/heat_equation/generate_files.py +++ b/test-data/ert/heat_equation/generate_files.py @@ -2,8 +2,8 @@ Contains code that was used to generate files expected by ert. """ +from collections.abc import Callable from textwrap import dedent -from typing import Callable, List import numpy as np import numpy.typing as npt @@ -27,7 +27,7 @@ def create_egrid_file(): def make_observations( - coordinates: List[Coordinate], + coordinates: list[Coordinate], times: npt.NDArray[np.int_], field: npt.NDArray[np.float64], error: Callable, @@ -57,7 +57,7 @@ def make_observations( # See documentation for details. value = field[k, coordinate.x, coordinate.y] sd = error(value) - _df = pd.DataFrame( + df_ = pd.DataFrame( { "k": [k], "x": [coordinate.x], @@ -66,7 +66,7 @@ def make_observations( "sd": [sd], } ) - d = pd.concat([d, _df]) + d = pd.concat([d, df_]) d = d.set_index(["k", "x", "y"], verify_integrity=True) return d diff --git a/test-data/ert/heat_equation/heat_equation.py b/test-data/ert/heat_equation/heat_equation.py index e36862ea1aa..6ec0231cacd 100755 --- a/test-data/ert/heat_equation/heat_equation.py +++ b/test-data/ert/heat_equation/heat_equation.py @@ -2,7 +2,6 @@ """Partial Differential Equations to use as forward models.""" import sys -from typing import Optional import geostat import numpy as np @@ -19,14 +18,14 @@ def heat_equation( k_start: int, k_end: int, rng: np.random.Generator, - scale: Optional[float] = None, + scale: float | None = None, ) -> npt.NDArray[np.float64]: """2D heat equation that suppoheat_erts field of heat coefficients. Based on: https://levelup.gitconnected.com/solving-2d-heat-equation-numerically-using-python-3334004aa01a """ - _u = u.copy() + u_ = u.copy() nx = u.shape[1] # number of grid cells assert cond.shape == (nx, nx) @@ -36,20 +35,20 @@ def heat_equation( for i in range(1, plate_length - 1, dx): for j in range(1, plate_length - 1, dx): noise = rng.normal(scale=scale) if scale is not None else 0 - _u[k + 1, i, j] = ( + u_[k + 1, i, j] = ( gamma[i, j] * ( - _u[k][i + 1][j] - + _u[k][i - 1][j] - + _u[k][i][j + 1] - + _u[k][i][j - 1] - - 4 * _u[k][i][j] + u_[k][i + 1][j] + + u_[k][i - 1][j] + + u_[k][i][j + 1] + + u_[k][i][j - 1] + - 4 * u_[k][i][j] ) - + _u[k][i][j] + + u_[k][i][j] + noise ) - return _u + return u_ def sample_prior_conductivity(ensemble_size, nx, rng): diff --git a/test-data/ert/snake_oil/forward_models/snake_oil_simulator.py b/test-data/ert/snake_oil/forward_models/snake_oil_simulator.py index 144fdfc74c0..4efbbdea745 100755 --- a/test-data/ert/snake_oil/forward_models/snake_oil_simulator.py +++ b/test-data/ert/snake_oil/forward_models/snake_oil_simulator.py @@ -11,7 +11,7 @@ def globalIndex(i, j, k, nx=10, ny=10, nz=10): def readParameters(filename): params = {} - with open(filename, "r", encoding="utf-8") as f: + with open(filename, encoding="utf-8") as f: for line in f: key, value = line.split(":", 1) params[key] = value.strip() @@ -151,7 +151,7 @@ def runSimulator(simulator, history_simulator, time_step_count) -> Summary: def roundedInt(value): - return int(round(float(value))) + return round(float(value)) if __name__ == "__main__": diff --git a/test-data/ert/snake_oil_field/forward_models/snake_oil_simulator.py b/test-data/ert/snake_oil_field/forward_models/snake_oil_simulator.py index 098ea37468e..ebb6d628cd6 100755 --- a/test-data/ert/snake_oil_field/forward_models/snake_oil_simulator.py +++ b/test-data/ert/snake_oil_field/forward_models/snake_oil_simulator.py @@ -11,7 +11,7 @@ def globalIndex(i, j, k, nx=10, ny=10): def read_seed(filename): params = {} - with open(filename, "r", encoding="utf-8") as f: + with open(filename, encoding="utf-8") as f: for line in f: key, value = line.split(":") params[key] = value.strip() @@ -21,7 +21,7 @@ def read_seed(filename): def read_parameters(filename): params = {} - with open(filename, "r", encoding="utf-8") as f: + with open(filename, encoding="utf-8") as f: for line in f: key, value = line.split(" ") _, name = key.split(":") @@ -146,7 +146,7 @@ def runSimulator(simulator, history_simulator, time_step_count): def roundedInt(value): - return int(round(float(value))) + return round(float(value)) if __name__ == "__main__": diff --git a/test-data/everest/math_func/jobs/adv_distance3.py b/test-data/everest/math_func/jobs/adv_distance3.py index 200cec442a9..ee1f5ea4e80 100755 --- a/test-data/everest/math_func/jobs/adv_distance3.py +++ b/test-data/everest/math_func/jobs/adv_distance3.py @@ -12,7 +12,7 @@ def compute_distance_squared(p, q): def read_point(filename): - with open(filename, "r", encoding="utf-8") as f: + with open(filename, encoding="utf-8") as f: point = json.load(f) x = point["x"] return x["0"], x["1"], x["2"] diff --git a/test-data/everest/math_func/jobs/adv_dump_controls.py b/test-data/everest/math_func/jobs/adv_dump_controls.py index 19b1bb9a435..457c48be1fb 100755 --- a/test-data/everest/math_func/jobs/adv_dump_controls.py +++ b/test-data/everest/math_func/jobs/adv_dump_controls.py @@ -12,7 +12,7 @@ def main(argv): arg_parser.add_argument("--out-suffix", type=str, default="") opts, _ = arg_parser.parse_known_args(args=argv) - with open(opts.controls_file, "r", encoding="utf-8") as f: + with open(opts.controls_file, encoding="utf-8") as f: controls = json.load(f) for name, indices in controls.items(): diff --git a/test-data/everest/math_func/jobs/discrete.py b/test-data/everest/math_func/jobs/discrete.py index 44bb659b99c..3853760317f 100755 --- a/test-data/everest/math_func/jobs/discrete.py +++ b/test-data/everest/math_func/jobs/discrete.py @@ -10,7 +10,7 @@ def compute_func(x, y): def read_point(filename): - with open(filename, "r", encoding="utf-8") as f: + with open(filename, encoding="utf-8") as f: point = json.load(f) return point["x"], point["y"] diff --git a/test-data/everest/math_func/jobs/distance3.py b/test-data/everest/math_func/jobs/distance3.py index bdfb8f94a3d..a7a2be714fe 100755 --- a/test-data/everest/math_func/jobs/distance3.py +++ b/test-data/everest/math_func/jobs/distance3.py @@ -12,7 +12,7 @@ def compute_distance_squared(p, q): def read_point(filename): - with open(filename, "r", encoding="utf-8") as f: + with open(filename, encoding="utf-8") as f: point = json.load(f) return point["x"], point["y"], point["z"] diff --git a/test-data/everest/math_func/jobs/dump_controls.py b/test-data/everest/math_func/jobs/dump_controls.py index be45295703d..de7c431f674 100755 --- a/test-data/everest/math_func/jobs/dump_controls.py +++ b/test-data/everest/math_func/jobs/dump_controls.py @@ -12,7 +12,7 @@ def main(argv): arg_parser.add_argument("--out-suffix", type=str, default="") opts, _ = arg_parser.parse_known_args(args=argv) - with open(opts.controls_file, "r", encoding="utf-8") as f: + with open(opts.controls_file, encoding="utf-8") as f: controls = json.load(f) for k, v in controls.items(): diff --git a/tests/ert/__init__.py b/tests/ert/__init__.py index ceea10c5153..ee05dbc228c 100644 --- a/tests/ert/__init__.py +++ b/tests/ert/__init__.py @@ -1,8 +1,9 @@ import importlib.util import sys +from collections.abc import Sequence from copy import deepcopy from datetime import datetime -from typing import Any, Dict, Optional, Sequence +from typing import Any from pydantic import BaseModel @@ -27,16 +28,16 @@ def import_from_location(name, location): class SnapshotBuilder(BaseModel): - fm_steps: Dict[str, FMStepSnapshot] = {} - metadata: Dict[str, Any] = {} + fm_steps: dict[str, FMStepSnapshot] = {} + metadata: dict[str, Any] = {} def build( self, real_ids: Sequence[str], - status: Optional[str], - exec_hosts: Optional[str] = None, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, + status: str | None, + exec_hosts: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, ) -> EnsembleSnapshot: snapshot = EnsembleSnapshot() snapshot._ensemble_state = status @@ -60,14 +61,14 @@ def add_fm_step( self, fm_step_id: str, index: str, - name: Optional[str], - status: Optional[str], - current_memory_usage: Optional[str] = None, - max_memory_usage: Optional[str] = None, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, - stdout: Optional[str] = None, - stderr: Optional[str] = None, + name: str | None, + status: str | None, + current_memory_usage: str | None = None, + max_memory_usage: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + stdout: str | None = None, + stderr: str | None = None, ) -> "SnapshotBuilder": self.fm_steps[fm_step_id] = _filter_nones( FMStepSnapshot( diff --git a/tests/ert/conftest.py b/tests/ert/conftest.py index 8d8d5f528b7..8fbe828cc60 100644 --- a/tests/ert/conftest.py +++ b/tests/ert/conftest.py @@ -210,16 +210,16 @@ def copy_heat_equation(copy_case): pytest.param(0, marks=pytest.mark.xdist_group(name="snake_oil_case_storage")) ], ) -def fixture_copy_snake_oil_case_storage(_shared_snake_oil_case, tmp_path, monkeypatch): +def fixture_copy_snake_oil_case_storage(shared_snake_oil_case, tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) - shutil.copytree(_shared_snake_oil_case, "test_data") + shutil.copytree(shared_snake_oil_case, "test_data") monkeypatch.chdir("test_data") @pytest.fixture -def copy_heat_equation_storage(_shared_heat_equation, tmp_path, monkeypatch): +def copy_heat_equation_storage(shared_heat_equation, tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) - shutil.copytree(_shared_heat_equation, "heat_equation") + shutil.copytree(shared_heat_equation, "heat_equation") monkeypatch.chdir("heat_equation") diff --git a/tests/ert/performance_tests/test_analysis.py b/tests/ert/performance_tests/test_analysis.py index 090ce65432e..282280a8958 100644 --- a/tests/ert/performance_tests/test_analysis.py +++ b/tests/ert/performance_tests/test_analysis.py @@ -7,9 +7,7 @@ import xtgeo from scipy.ndimage import gaussian_filter -from ert.analysis import ( - smoother_update, -) +from ert.analysis import smoother_update from ert.config import Field, GenDataConfig from ert.config.analysis_config import UpdateSettings from ert.config.analysis_module import ESSettings diff --git a/tests/ert/performance_tests/test_dark_storage_performance.py b/tests/ert/performance_tests/test_dark_storage_performance.py index 758530d9cbb..396a76aab36 100644 --- a/tests/ert/performance_tests/test_dark_storage_performance.py +++ b/tests/ert/performance_tests/test_dark_storage_performance.py @@ -4,8 +4,9 @@ import os import time from asyncio import get_event_loop +from collections.abc import Awaitable from datetime import datetime, timedelta -from typing import Awaitable, TypeVar +from typing import TypeVar from urllib.parse import quote import memray diff --git a/tests/ert/performance_tests/test_obs_and_responses_performance.py b/tests/ert/performance_tests/test_obs_and_responses_performance.py index 3dc279ddb46..9e27a1275bf 100644 --- a/tests/ert/performance_tests/test_obs_and_responses_performance.py +++ b/tests/ert/performance_tests/test_obs_and_responses_performance.py @@ -3,7 +3,6 @@ import time from dataclasses import dataclass from textwrap import dedent -from typing import List import memray import numpy as np @@ -247,7 +246,7 @@ class _Benchmark: # hence they are all declared here # Note: Adjusting num responses did not seem # to have a very big impact on performance. -_BenchMarks: List[_Benchmark] = [ +_BenchMarks: list[_Benchmark] = [ _Benchmark( alias="small", config=_UpdatePerfTestConfig( diff --git a/tests/ert/performance_tests/test_read_summary.py b/tests/ert/performance_tests/test_read_summary.py index c1c754c7e64..199944f7044 100644 --- a/tests/ert/performance_tests/test_read_summary.py +++ b/tests/ert/performance_tests/test_read_summary.py @@ -1,9 +1,7 @@ from hypothesis import given from ert.config._read_summary import read_summary -from tests.ert.unit_tests.config.summary_generator import ( - summaries, -) +from tests.ert.unit_tests.config.summary_generator import summaries @given(summaries()) diff --git a/tests/ert/ui_tests/cli/analysis/test_es_update.py b/tests/ert/ui_tests/cli/analysis/test_es_update.py index bc432fce02f..39897af2482 100644 --- a/tests/ert/ui_tests/cli/analysis/test_es_update.py +++ b/tests/ert/ui_tests/cli/analysis/test_es_update.py @@ -140,10 +140,10 @@ def sample_prior(nx, ny): # Check that surfaces defined in INIT_FILES are not changed by ERT surf_prior = ens_prior.load_parameters("TOP", list(range(ensemble_size)))["values"] for i in range(ensemble_size): - _prior_init = surface_from_file( + prior_init = surface_from_file( f"surface/surf_init_{i}.irap", fformat="irap_ascii", dtype=np.float32 ) - np.testing.assert_array_equal(surf_prior[i], _prior_init.values.data) + np.testing.assert_array_equal(surf_prior[i], prior_init.values.data) surf_posterior = ens_posterior.load_parameters("TOP", list(range(ensemble_size)))[ "values" diff --git a/tests/ert/ui_tests/cli/test_cli.py b/tests/ert/ui_tests/cli/test_cli.py index fb775e6a886..91b6fc7b843 100644 --- a/tests/ert/ui_tests/cli/test_cli.py +++ b/tests/ert/ui_tests/cli/test_cli.py @@ -67,7 +67,7 @@ def test_test_run_on_lsf_configuration_works_with_no_errors(tmp_path): @pytest.mark.usefixtures("copy_poly_case") def test_that_the_cli_raises_exceptions_when_parameters_are_missing(mode): with ( - open("poly.ert", "r", encoding="utf-8") as fin, + open("poly.ert", encoding="utf-8") as fin, open("poly-no-gen-kw.ert", "w", encoding="utf-8") as fout, ): for line in fin: @@ -153,7 +153,7 @@ def test_surface_init_fails_during_forward_model_callback(): @pytest.mark.usefixtures("copy_snake_oil_field") def test_unopenable_observation_config_fails_gracefully(): config_file_name = "snake_oil_field.ert" - with open(config_file_name, mode="r", encoding="utf-8") as config_file_handler: + with open(config_file_name, encoding="utf-8") as config_file_handler: content_lines = config_file_handler.read().splitlines() index_line_with_observation_config = next( index @@ -186,7 +186,7 @@ def test_that_the_model_raises_exception_if_successful_realizations_less_than_mi mode, ): with ( - open("poly.ert", "r", encoding="utf-8") as fin, + open("poly.ert", encoding="utf-8") as fin, open("failing_realizations.ert", "w", encoding="utf-8") as fout, ): for line in fin: @@ -231,7 +231,7 @@ def test_that_the_model_warns_when_active_realizations_less_min_realizations(mod A warning is issued when NUM_REALIZATIONS is higher than active_realizations. """ with ( - open("poly.ert", "r", encoding="utf-8") as fin, + open("poly.ert", encoding="utf-8") as fin, open("poly_lower_active_reals.ert", "w", encoding="utf-8") as fout, ): for line in fin: @@ -883,7 +883,7 @@ def test_that_log_is_cleaned_up_from_repeated_forward_model_steps(caplog): there are repeated forward models """ with ( - open("poly.ert", "r", encoding="utf-8") as fin, + open("poly.ert", encoding="utf-8") as fin, open("poly_repeated_forward_model_steps.ert", "w", encoding="utf-8") as fout, ): forward_model_steps = ["FORWARD_MODEL poly_eval\n"] * 5 diff --git a/tests/ert/ui_tests/cli/test_parameter_passing.py b/tests/ert/ui_tests/cli/test_parameter_passing.py index b9bf9ccebb7..ddc7b8f23ed 100644 --- a/tests/ert/ui_tests/cli/test_parameter_passing.py +++ b/tests/ert/ui_tests/cli/test_parameter_passing.py @@ -11,7 +11,7 @@ from datetime import datetime from enum import Enum, auto from pathlib import Path -from typing import Literal, Optional, Tuple +from typing import Literal import cwrap import hypothesis.strategies as st @@ -141,7 +141,7 @@ def write_grid_file(self, grid_name: str, grid_format: Literal["grid", "egrid"]) else: raise ValueError() - def _random_values(self, shape: Tuple[int, ...], name: str): + def _random_values(self, shape: tuple[int, ...], name: str): return self.data.draw( arrays( elements=st.floats(min_value=2.0, max_value=4.0, width=32), @@ -231,10 +231,10 @@ class FieldParameter(Parameter): name: str infformat: FieldFileFormat outfformat: FieldFileFormat - min: Optional[float] - max: Optional[float] - input_transform: Optional[Transform] - output_transform: Optional[Transform] + min: float | None + max: float | None + input_transform: Transform | None + output_transform: Transform | None forward_init: bool @property diff --git a/tests/ert/ui_tests/gui/conftest.py b/tests/ert/ui_tests/gui/conftest.py index 2f95e01bd02..c9cb07e37c9 100644 --- a/tests/ert/ui_tests/gui/conftest.py +++ b/tests/ert/ui_tests/gui/conftest.py @@ -5,10 +5,11 @@ import shutil import stat import time +from collections.abc import Iterator from contextlib import contextmanager from pathlib import Path from textwrap import dedent -from typing import Iterator, Type, TypeVar +from typing import TypeVar from unittest.mock import MagicMock, Mock import pytest @@ -84,9 +85,7 @@ def _new_poly_example(source_root, destination, num_realizations: int = 20): @contextmanager -def _open_main_window( - path, -) -> Iterator[tuple[ErtMainWindow, Storage, ErtConfig]]: +def _open_main_window(path) -> Iterator[tuple[ErtMainWindow, Storage, ErtConfig]]: args_mock = Mock() args_mock.config = str(path) with ErtPluginContext(): @@ -184,9 +183,9 @@ def _evaluate(coeffs, x): @pytest.fixture -def esmda_has_run(_esmda_run, tmp_path, monkeypatch): +def esmda_has_run(esmda_run, tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) - shutil.copytree(_esmda_run, tmp_path, dirs_exist_ok=True) + shutil.copytree(esmda_run, tmp_path, dirs_exist_ok=True) with ( _open_main_window(tmp_path / "poly.ert") as ( gui, @@ -390,17 +389,17 @@ def handle_add_dialog(): V = TypeVar("V") -def wait_for_child(gui, qtbot: QtBot, typ: Type[V], timeout=5000, **kwargs) -> V: +def wait_for_child(gui, qtbot: QtBot, typ: type[V], timeout=5000, **kwargs) -> V: qtbot.waitUntil(lambda: gui.findChild(typ) is not None, timeout=timeout) return get_child(gui, typ, **kwargs) -def get_child(gui: QWidget, typ: Type[V], *args, **kwargs) -> V: +def get_child(gui: QWidget, typ: type[V], *args, **kwargs) -> V: child = gui.findChild(typ, *args, **kwargs) assert isinstance(child, typ) return child -def get_children(gui: QWidget, typ: Type[V], *args, **kwargs) -> list[V]: +def get_children(gui: QWidget, typ: type[V], *args, **kwargs) -> list[V]: children: list[typ] = gui.findChildren(typ, *args, **kwargs) return children diff --git a/tests/ert/ui_tests/gui/test_full_manual_update_workflow.py b/tests/ert/ui_tests/gui/test_full_manual_update_workflow.py index 674423ae0ee..d1784c5ee46 100644 --- a/tests/ert/ui_tests/gui/test_full_manual_update_workflow.py +++ b/tests/ert/ui_tests/gui/test_full_manual_update_workflow.py @@ -3,12 +3,7 @@ import numpy as np from qtpy.QtCore import Qt -from qtpy.QtWidgets import ( - QComboBox, - QToolButton, - QTreeView, - QWidget, -) +from qtpy.QtWidgets import QComboBox, QToolButton, QTreeView, QWidget from ert.data import MeasuredData from ert.gui.simulation.evaluate_ensemble_panel import EvaluateEnsemblePanel diff --git a/tests/ert/ui_tests/gui/test_load_results_manually.py b/tests/ert/ui_tests/gui/test_load_results_manually.py index e7ecc60619f..8b72648fb61 100644 --- a/tests/ert/ui_tests/gui/test_load_results_manually.py +++ b/tests/ert/ui_tests/gui/test_load_results_manually.py @@ -5,10 +5,7 @@ from ert.gui.ertwidgets.ensembleselector import EnsembleSelector from ert.gui.tools.load_results import LoadResultsPanel -from .conftest import ( - get_child, - wait_for_child, -) +from .conftest import get_child, wait_for_child def test_validation(ensemble_experiment_has_run_no_failure, qtbot): diff --git a/tests/ert/ui_tests/gui/test_main_window.py b/tests/ert/ui_tests/gui/test_main_window.py index 69f5a420b81..daabc9047c0 100644 --- a/tests/ert/ui_tests/gui/test_main_window.py +++ b/tests/ert/ui_tests/gui/test_main_window.py @@ -4,7 +4,6 @@ import stat from pathlib import Path from textwrap import dedent -from typing import List from unittest.mock import MagicMock, Mock, patch import numpy as np @@ -145,7 +144,7 @@ def test_gui_shows_a_warning_and_disables_update_when_parameters_are_missing( qapp, tmp_path ): with ( - open("poly.ert", "r", encoding="utf-8") as fin, + open("poly.ert", encoding="utf-8") as fin, open("poly-no-gen-kw.ert", "w", encoding="utf-8") as fout, ): for line in fin: @@ -360,10 +359,7 @@ def test_that_the_plot_window_contains_the_expected_elements( plot_window.close() -def test_that_the_manage_experiments_tool_can_be_used( - esmda_has_run, - qtbot, -): +def test_that_the_manage_experiments_tool_can_be_used(esmda_has_run, qtbot): gui = esmda_has_run button_manage_experiments = gui.findChild(QToolButton, "button_Manage_experiments") @@ -617,7 +613,7 @@ def test_right_click_plot_button_opens_external_plotter(qtbot, storage, monkeypa button_plot_tool = gui.findChild(SidebarToolButton, "button_Create_plot") assert button_plot_tool - def top_level_plotter_windows() -> List[QWindow]: + def top_level_plotter_windows() -> list[QWindow]: top_level_plot_windows = [] top_level_windows = QApplication.topLevelWindows() for win in top_level_windows: diff --git a/tests/ert/ui_tests/gui/test_missing_runpath.py b/tests/ert/ui_tests/gui/test_missing_runpath.py index 2cdaf38f1d5..e6118037a85 100644 --- a/tests/ert/ui_tests/gui/test_missing_runpath.py +++ b/tests/ert/ui_tests/gui/test_missing_runpath.py @@ -2,9 +2,7 @@ from contextlib import suppress from qtpy.QtCore import QTimer -from qtpy.QtWidgets import ( - QLabel, -) +from qtpy.QtWidgets import QLabel from ert.ensemble_evaluator.state import ENSEMBLE_STATE_FAILED from ert.gui.simulation.run_dialog import RunDialog diff --git a/tests/ert/ui_tests/gui/test_restart_ensemble_experiment.py b/tests/ert/ui_tests/gui/test_restart_ensemble_experiment.py index a6b29efa13c..04910bb8890 100644 --- a/tests/ert/ui_tests/gui/test_restart_ensemble_experiment.py +++ b/tests/ert/ui_tests/gui/test_restart_ensemble_experiment.py @@ -2,7 +2,6 @@ import random import stat from textwrap import dedent -from typing import Set from qtpy.QtCore import Qt, QTimer from qtpy.QtWidgets import QComboBox, QMessageBox, QWidget @@ -20,7 +19,7 @@ def test_restart_failed_realizations(opened_main_window_poly, qtbot): """ gui = opened_main_window_poly - def write_poly_eval(failing_reals: Set[int]): + def write_poly_eval(failing_reals: set[int]): with open("poly_eval.py", "w", encoding="utf-8") as f: f.write( dedent( diff --git a/tests/ert/ui_tests/gui/test_restart_no_responses_and_parameters.py b/tests/ert/ui_tests/gui/test_restart_no_responses_and_parameters.py index 0cf03a2660d..0cc9cc1cc37 100644 --- a/tests/ert/ui_tests/gui/test_restart_no_responses_and_parameters.py +++ b/tests/ert/ui_tests/gui/test_restart_no_responses_and_parameters.py @@ -1,8 +1,8 @@ import os import stat +from collections.abc import Generator from contextlib import contextmanager from textwrap import dedent -from typing import Generator, Tuple from unittest.mock import Mock import pytest @@ -29,7 +29,7 @@ @contextmanager def _open_main_window( path, -) -> Generator[Tuple[ErtMainWindow, Storage, ErtConfig], None, None]: +) -> Generator[tuple[ErtMainWindow, Storage, ErtConfig], None, None]: with open("forward_model.py", "w", encoding="utf-8") as f: f.write( dedent( diff --git a/tests/ert/ui_tests/gui/test_single_test_run.py b/tests/ert/ui_tests/gui/test_single_test_run.py index a1d60b178d0..536808c355e 100644 --- a/tests/ert/ui_tests/gui/test_single_test_run.py +++ b/tests/ert/ui_tests/gui/test_single_test_run.py @@ -2,10 +2,7 @@ import shutil from qtpy.QtCore import Qt -from qtpy.QtWidgets import ( - QComboBox, - QWidget, -) +from qtpy.QtWidgets import QComboBox, QWidget from ert.gui.simulation.experiment_panel import ExperimentPanel from ert.gui.simulation.run_dialog import RunDialog diff --git a/tests/ert/ui_tests/gui/test_workflow_tool.py b/tests/ert/ui_tests/gui/test_workflow_tool.py index 136647e2f90..dc348b6e02b 100644 --- a/tests/ert/ui_tests/gui/test_workflow_tool.py +++ b/tests/ert/ui_tests/gui/test_workflow_tool.py @@ -1,7 +1,7 @@ import os +from collections.abc import Generator from contextlib import contextmanager from textwrap import dedent -from typing import Generator, Tuple from unittest.mock import Mock import pytest @@ -24,7 +24,7 @@ @contextmanager def _open_main_window( path, -) -> Generator[Tuple[ErtMainWindow, Storage, ErtConfig], None, None]: +) -> Generator[tuple[ErtMainWindow, Storage, ErtConfig], None, None]: (path / "config.ert").write_text( dedent(""" QUEUE_SYSTEM LOCAL diff --git a/tests/ert/unit_tests/config/config_dict_generator.py b/tests/ert/unit_tests/config/config_dict_generator.py index d747afbddf9..e0485a1ffda 100644 --- a/tests/ert/unit_tests/config/config_dict_generator.py +++ b/tests/ert/unit_tests/config/config_dict_generator.py @@ -6,7 +6,7 @@ from collections import defaultdict from dataclasses import dataclass, fields from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Literal from warnings import filterwarnings import hypothesis.strategies as st @@ -142,7 +142,7 @@ def valid_queue_options(queue_system: str): ] -queue_options_by_type: Dict[str, Dict[str, List[str]]] = defaultdict(dict) +queue_options_by_type: dict[str, dict[str, list[str]]] = defaultdict(dict) for system, options in queue_systems_and_options.items(): queue_options_by_type["string"][system.name] = [ field.name.upper() @@ -254,23 +254,23 @@ def random_forward_model_names(draw, some_words, some_file_names): @dataclass class ErtConfigValues: num_realizations: PositiveInt - eclbase: Optional[str] + eclbase: str | None runpath_file: str - run_template: List[str] + run_template: list[str] enkf_alpha: float update_log_path: str std_cutoff: float max_runtime: PositiveInt min_realizations: str - define: List[Tuple[str, str]] - forward_model: Tuple[str, List[Tuple[str, str]]] - simulation_job: List[List[str]] + define: list[tuple[str, str]] + forward_model: tuple[str, list[tuple[str, str]]] + simulation_job: list[list[str]] stop_long_running: bool - data_kw_key: List[Tuple[str, str]] + data_kw_key: list[tuple[str, str]] data_file: str grid_file: str job_script: str - jobname: Optional[str] + jobname: str | None runpath: str enspath: str time_map: str @@ -278,23 +278,23 @@ class ErtConfigValues: history_source: HistorySource refcase: str gen_kw_export_name: str - field: List[Tuple[str, ...]] - gen_data: List[Tuple[str, ...]] + field: list[tuple[str, ...]] + gen_data: list[tuple[str, ...]] max_submit: PositiveInt num_cpu: PositiveInt queue_system: Literal["LSF", "LOCAL", "TORQUE", "SLURM"] - queue_option: List[Union[Tuple[str, str], Tuple[str, str, str]]] - analysis_set_var: List[Tuple[str, str, Any]] - install_job: List[Tuple[str, str]] - install_job_directory: List[str] + queue_option: list[tuple[str, str] | tuple[str, str, str]] + analysis_set_var: list[tuple[str, str, Any]] + install_job: list[tuple[str, str]] + install_job_directory: list[str] license_path: str random_seed: int - setenv: List[Tuple[str, str]] - observations: List[Observation] + setenv: list[tuple[str, str]] + observations: list[Observation] refcase_smspec: Smspec refcase_unsmry: Unsmry egrid: EGrid - datetimes: List[datetime.datetime] + datetimes: list[datetime.datetime] def to_config_dict(self, config_file, cwd, all_defines=True): result = { @@ -425,7 +425,7 @@ def ert_config_values(draw, use_eclbase=booleans): ) ) need_eclbase = any( - (isinstance(val, (HistoryObservation, SummaryObservation)) for val in obs) + isinstance(val, HistoryObservation | SummaryObservation) for val in obs ) use_eclbase = draw(use_eclbase) if not need_eclbase else True dates = _observation_dates(obs, first_date) @@ -538,7 +538,7 @@ def sim_job(installed_jobs): def _observation_dates( observations, start_date: datetime.datetime -) -> List[datetime.datetime]: +) -> list[datetime.datetime]: """ :returns: the dates that need to exist in the refcase for ert to accept the observations diff --git a/tests/ert/unit_tests/config/egrid_generator.py b/tests/ert/unit_tests/config/egrid_generator.py index df0a252646b..912fbc191e6 100644 --- a/tests/ert/unit_tests/config/egrid_generator.py +++ b/tests/ert/unit_tests/config/egrid_generator.py @@ -1,6 +1,6 @@ from dataclasses import astuple, dataclass from enum import Enum, auto, unique -from typing import Any, List, Optional, Tuple +from typing import Any import hypothesis.strategies as st import numpy as np @@ -59,7 +59,7 @@ class GrdeclKeyword: ... return [self.field1.to_ecl(), self.field2.to_ecl] """ - def to_ecl(self) -> List[Any]: + def to_ecl(self) -> list[Any]: return [value.to_ecl() for value in astuple(self)] @@ -197,8 +197,8 @@ class GridHead: numres: int nseg: int coordinate_type: CoordinateType - lgr_start: Tuple[int, int, int] - lgr_end: Tuple[int, int, int] + lgr_start: tuple[int, int, int] + lgr_end: tuple[int, int, int] def to_ecl(self) -> np.ndarray: # The data is expected to consist of @@ -228,7 +228,7 @@ class GlobalGrid: grid_head: GridHead coord: np.ndarray zcorn: np.ndarray - actnum: Optional[np.ndarray] = None + actnum: np.ndarray | None = None def __eq__(self, other: object) -> bool: if not isinstance(other, GlobalGrid): @@ -240,7 +240,7 @@ def __eq__(self, other: object) -> bool: and np.array_equal(self.zcorn, other.zcorn) ) - def to_ecl(self) -> List[Tuple[str, Any]]: + def to_ecl(self) -> list[tuple[str, Any]]: result = [ ("GRIDHEAD", self.grid_head.to_ecl()), ("COORD ", self.coord.astype(np.float32)), @@ -268,7 +268,7 @@ class EGrid: global_grid: GlobalGrid @property - def shape(self) -> Tuple[int, int, int]: + def shape(self) -> tuple[int, int, int]: grid_head = self.global_grid.grid_head return (grid_head.num_x, grid_head.num_y, grid_head.num_z) @@ -366,7 +366,7 @@ def global_grids(draw): egrids = st.builds(EGrid, file_heads, global_grids()) -def simple_grid(dims: Tuple[int, int, int] = (2, 2, 2)): +def simple_grid(dims: tuple[int, int, int] = (2, 2, 2)): corner_size = (dims[0] + 1) * (dims[1] + 1) * 6 coord = np.zeros( shape=corner_size, diff --git a/tests/ert/unit_tests/config/observations_generator.py b/tests/ert/unit_tests/config/observations_generator.py index 9853ab2113c..0e66c7a2fd3 100644 --- a/tests/ert/unit_tests/config/observations_generator.py +++ b/tests/ert/unit_tests/config/observations_generator.py @@ -2,7 +2,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field, fields from enum import Enum, auto -from typing import List, Optional import hypothesis.strategies as st from hypothesis import assume @@ -38,7 +37,7 @@ def __str__(self): continue if isinstance(val, Enum): result += f"{f.name.upper()} = {val.name}; " - elif isinstance(val, (float, str, int)): + elif isinstance(val, float | str | int): result += f"{f.name.upper()} = {val}; " elif isinstance(val, Observation): result += str(val) @@ -64,7 +63,7 @@ def class_name(self): @dataclass class HistoryObservation(Observation): error_mode: ErrorMode - segment: List[Segment] = field(default_factory=list) + segment: list[Segment] = field(default_factory=list) @property def class_name(self): @@ -80,10 +79,10 @@ class SummaryObservation(Observation): key: str error_min: PositiveFloat error_mode: ErrorMode - days: Optional[float] = None - hours: Optional[float] = None - restart: Optional[int] = None - date: Optional[str] = None + days: float | None = None + hours: float | None = None + restart: int | None = None + date: str | None = None @property def class_name(self): @@ -103,13 +102,13 @@ def get_date(self, start): @dataclass class GeneralObservation(Observation): data: str - date: Optional[str] = None - days: Optional[float] = None - hours: Optional[float] = None - restart: Optional[int] = None - obs_file: Optional[str] = None - value: Optional[float] = None - index_list: Optional[List[int]] = None + date: str | None = None + days: float | None = None + hours: float | None = None + restart: int | None = None + obs_file: str | None = None + value: float | None = None + index_list: list[int] | None = None def get_date(self, start): if self.date is not None: diff --git a/tests/ert/unit_tests/config/parsing/test_lark_parser.py b/tests/ert/unit_tests/config/parsing/test_lark_parser.py index f97deafb488..c1393d03914 100644 --- a/tests/ert/unit_tests/config/parsing/test_lark_parser.py +++ b/tests/ert/unit_tests/config/parsing/test_lark_parser.py @@ -3,11 +3,7 @@ import pytest -from ert.config.parsing import ( - ConfigValidationError, - init_user_config_schema, - parse, -) +from ert.config.parsing import ConfigValidationError, init_user_config_schema, parse def touch(filename): diff --git a/tests/ert/unit_tests/config/test_ert_config.py b/tests/ert/unit_tests/config/test_ert_config.py index ad7dd4b6c75..da5709049a3 100644 --- a/tests/ert/unit_tests/config/test_ert_config.py +++ b/tests/ert/unit_tests/config/test_ert_config.py @@ -1609,7 +1609,7 @@ def test_that_context_types_are_json_serializable(): with open("test.json", "w", encoding="utf-8") as f: json.dump(payload, f, cls=ContextBoolEncoder) - with open("test.json", "r", encoding="utf-8") as f: + with open("test.json", encoding="utf-8") as f: r = json.load(f) assert isinstance(r["context_bool_false"], bool) diff --git a/tests/ert/unit_tests/config/test_forward_model_data_to_json.py b/tests/ert/unit_tests/config/test_forward_model_data_to_json.py index d26d29d4ff8..83f98c9499e 100644 --- a/tests/ert/unit_tests/config/test_forward_model_data_to_json.py +++ b/tests/ert/unit_tests/config/test_forward_model_data_to_json.py @@ -4,7 +4,6 @@ import os.path import stat from textwrap import dedent -from typing import List import pytest @@ -162,8 +161,8 @@ def _generate_step( return _forward_model_step_from_config_file(config_file, name) -def empty_list_if_none(_list): - return [] if _list is None else _list +def empty_list_if_none(list_): + return [] if list_ is None else list_ def default_name_if_none(name): @@ -236,7 +235,7 @@ def generate_step_from_dict(forward_model_config): return forward_model -def set_up_forward_model(fm_steplist) -> List[ForwardModelStep]: +def set_up_forward_model(fm_steplist) -> list[ForwardModelStep]: return [generate_step_from_dict(step) for step in fm_steplist] @@ -393,7 +392,7 @@ def test_various_null_fields(fm_step_list, context): @pytest.mark.usefixtures("use_tmpdir") def test_that_values_with_brackets_are_ommitted(caplog, fm_step_list, context): - forward_model_list: List[ForwardModelStep] = set_up_forward_model(fm_step_list) + forward_model_list: list[ForwardModelStep] = set_up_forward_model(fm_step_list) forward_model_list[0].environment["ENV_VAR"] = "" run_id = "test_no_jobs_id" diff --git a/tests/ert/unit_tests/config/test_gen_data_config.py b/tests/ert/unit_tests/config/test_gen_data_config.py index f087916409a..9a2d085cd41 100644 --- a/tests/ert/unit_tests/config/test_gen_data_config.py +++ b/tests/ert/unit_tests/config/test_gen_data_config.py @@ -1,7 +1,6 @@ import os from contextlib import suppress from pathlib import Path -from typing import List import hypothesis.strategies as st import pytest @@ -18,7 +17,7 @@ ], ) @pytest.mark.usefixtures("use_tmpdir") -def test_gen_data_config(name: str, report_steps: List[int]): +def test_gen_data_config(name: str, report_steps: list[int]): gdc = GenDataConfig(keys=[name], report_steps_list=[report_steps]) assert gdc.keys == [name] assert gdc.report_steps_list[0] == sorted(report_steps) diff --git a/tests/ert/unit_tests/config/test_gen_kw_config.py b/tests/ert/unit_tests/config/test_gen_kw_config.py index adede8db68d..71127de168f 100644 --- a/tests/ert/unit_tests/config/test_gen_kw_config.py +++ b/tests/ert/unit_tests/config/test_gen_kw_config.py @@ -78,7 +78,7 @@ def test_gen_kw_config_get_priors(): f.write("KEY10 CONST 10\n") transform_function_definitions = [] - with open(parameter_file, "r", encoding="utf-8") as file: + with open(parameter_file, encoding="utf-8") as file: for item in file: items = item.split() transform_function_definitions.append( diff --git a/tests/ert/unit_tests/config/test_observations.py b/tests/ert/unit_tests/config/test_observations.py index c32fb8b3c5c..8a9ebae238b 100644 --- a/tests/ert/unit_tests/config/test_observations.py +++ b/tests/ert/unit_tests/config/test_observations.py @@ -1299,9 +1299,7 @@ def test_that_unknown_key_in_is_handled(tmpdir, observation_type): ErtConfig.from_file("config.ert") -def test_validation_of_duplicate_names( - tmpdir, -): +def test_validation_of_duplicate_names(tmpdir): with tmpdir.as_cwd(): config = dedent( """ diff --git a/tests/ert/unit_tests/config/test_parser_error_collection.py b/tests/ert/unit_tests/config/test_parser_error_collection.py index 2c1ced4ec40..d2c3ca6a073 100644 --- a/tests/ert/unit_tests/config/test_parser_error_collection.py +++ b/tests/ert/unit_tests/config/test_parser_error_collection.py @@ -2,9 +2,9 @@ import re import stat import warnings +from collections.abc import Sequence from dataclasses import dataclass from textwrap import dedent -from typing import Dict, List, Optional, Sequence, Union import pytest from hypothesis import given, strategies @@ -26,15 +26,15 @@ class FileDetail: @dataclass class ExpectedErrorInfo: filename: str = "test.ert" - line: Optional[int] = None - column: Optional[int] = None - end_column: Optional[int] = None - other_files: Optional[Dict[str, FileDetail]] = None - match: Optional[str] = None - count: Optional[int] = None + line: int | None = None + column: int | None = None + end_column: int | None = None + other_files: dict[str, FileDetail] | None = None + match: str | None = None + count: int | None = None -def write_files(files: Optional[Dict[str, Union[str, FileDetail]]] = None): +def write_files(files: dict[str, str | FileDetail] | None = None): if files is not None: for other_filename, content in files.items(): with open(other_filename, mode="w", encoding="utf-8") as fh: @@ -51,7 +51,7 @@ def write_files(files: Optional[Dict[str, Union[str, FileDetail]]] = None): def find_and_assert_errors_matching_filename( - errors: Sequence[ErrorInfo], filename: Optional[str] + errors: Sequence[ErrorInfo], filename: str | None ): matching_errors = ( [err for err in errors if err.filename is not None and filename in err.filename] @@ -68,9 +68,9 @@ def find_and_assert_errors_matching_filename( def find_and_assert_errors_matching_location( errors: Sequence[ErrorInfo], - line: Optional[int] = None, - column: Optional[int] = None, - end_column: Optional[int] = None, + line: int | None = None, + column: int | None = None, + end_column: int | None = None, ): def equals_or_expected_any(actual, expected): return True if expected is None else actual == expected @@ -83,7 +83,7 @@ def equals_or_expected_any(actual, expected): and equals_or_expected_any(x.end_column, end_column) ] - def none_to_star(val: Optional[int] = None): + def none_to_star(val: int | None = None): return "*" if val is None else val assert len(matching_errors) > 0, ( @@ -97,7 +97,7 @@ def none_to_star(val: Optional[int] = None): def find_and_assert_errors_matching_message( - errors: List[ErrorInfo], match: Optional[str] = None + errors: list[ErrorInfo], match: str | None = None ): if match is None: return errors diff --git a/tests/ert/unit_tests/config/test_read_summary.py b/tests/ert/unit_tests/config/test_read_summary.py index 9caa7ea2fcb..b6984c85ea9 100644 --- a/tests/ert/unit_tests/config/test_read_summary.py +++ b/tests/ert/unit_tests/config/test_read_summary.py @@ -477,9 +477,7 @@ def test_missing_keywords_in_smspec_raises_informative_error( read_summary(str(tmp_path / "test"), ["*"]) -def test_that_ambiguous_case_restart_raises_an_informative_error( - tmp_path, -): +def test_that_ambiguous_case_restart_raises_an_informative_error(tmp_path): (tmp_path / "test.UNSMRY").write_bytes(b"") (tmp_path / "test.FUNSMRY").write_bytes(b"") (tmp_path / "test.smspec").write_bytes(b"") diff --git a/tests/ert/unit_tests/config/test_transfer_functions.py b/tests/ert/unit_tests/config/test_transfer_functions.py index 510765ea83e..f02c65f4324 100644 --- a/tests/ert/unit_tests/config/test_transfer_functions.py +++ b/tests/ert/unit_tests/config/test_transfer_functions.py @@ -119,16 +119,16 @@ def test_that_derrf_creates_at_least_steps_or_less_distinct_values(xlist, arg): @given(nice_floats(), valid_derrf_parameters()) def test_that_derrf_corresponds_scaled_binned_normal_cdf(x, arg): """Check correspondance to normal cdf with -mu=_skew and sd=_width""" - _steps, _min, _max, _skew, _width = arg - q_values = np.linspace(start=0, stop=1, num=_steps) - q_checks = np.linspace(start=0, stop=1, num=_steps + 1)[1:] - p = norm.cdf(x, loc=-_skew, scale=_width) + steps, min_, max_, skew, width = arg + q_values = np.linspace(start=0, stop=1, num=steps) + q_checks = np.linspace(start=0, stop=1, num=steps + 1)[1:] + p = norm.cdf(x, loc=-skew, scale=width) bin_index = np.digitize(p, q_checks, right=True) expected = q_values[bin_index] # scale and ensure ok numerics - expected = _min + expected * (_max - _min) - if expected > _max or expected < _min: - np.clip(expected, _min, _max) + expected = min_ + expected * (max_ - min_) + if expected > max_ or expected < min_: + np.clip(expected, min_, max_) assert np.isclose(TransformFunction.trans_derrf(x, arg), expected) @@ -165,8 +165,8 @@ def valid_triangular_params(): @given(nice_floats(), valid_triangular_params()) def test_that_triangular_is_within_bounds(x, args): - _mode, _min, _max = args - assert _min <= TransformFunction.trans_triangular(x, [_min, _mode, _max]) <= _max + mode, min_, max_ = args + assert min_ <= TransformFunction.trans_triangular(x, [min_, mode, max_]) <= max_ @given(valid_triangular_params()) @@ -175,12 +175,12 @@ def test_mode_behavior(args): When the CDF value of x (from the normal distribution) corresponds to the relative position of the mode in the triangular distribution, the output of trans_triangular should be the mode (_mode) of the triangular distribution. """ - _mode, _min, _max = args - ymode = (_mode - _min) / (_max - _min) + mode, min_, max_ = args + ymode = (mode - min_) / (max_ - min_) x = norm.ppf(ymode) - assert np.isclose(TransformFunction.trans_triangular(x, [_min, _mode, _max]), _mode) + assert np.isclose(TransformFunction.trans_triangular(x, [min_, mode, max_]), mode) @given(valid_triangular_params()) @@ -189,12 +189,12 @@ def test_that_triangular_is_symmetric_around_mode(args): For values of x equidistant from the CDF value at the mode, the outputs should be symmetrically placed around the mode. This property holds if the triangular distribution is symmetric. """ - _mode, _min, _max = args + mode, min_, max_ = args # Ensure the triangular distribution is symmetric - _mode = (_min + _max) / 2 + mode = (min_ + max_) / 2 - ymode = (_mode - _min) / (_max - _min) + ymode = (mode - min_) / (max_ - min_) delta = ymode / 2 # Find x1 and x2 such that their CDF values are equidistant from ymode @@ -202,20 +202,20 @@ def test_that_triangular_is_symmetric_around_mode(args): x2 = norm.ppf(ymode + delta) # Calculate the corresponding triangular values - y1 = TransformFunction.trans_triangular(x1, [_min, _mode, _max]) - y2 = TransformFunction.trans_triangular(x2, [_min, _mode, _max]) + y1 = TransformFunction.trans_triangular(x1, [min_, mode, max_]) + y2 = TransformFunction.trans_triangular(x2, [min_, mode, max_]) # Check if y1 and y2 are symmetric around the mode - assert abs((_mode - y1) - (y2 - _mode)) < 1e-15 * max( - *map(abs, [x1, x2, _min, _mode, _max]) + assert abs((mode - y1) - (y2 - mode)) < 1e-15 * max( + *map(abs, [x1, x2, min_, mode, max_]) ) @given(valid_triangular_params()) def test_that_triangular_is_monotonic(args): - _mode, _min, _max = args + mode, min_, max_ = args - ymode = (_mode - _min) / (_max - _min) + ymode = (mode - min_) / (max_ - min_) delta = 0.05 # Test both sides of the mode @@ -224,8 +224,8 @@ def test_that_triangular_is_monotonic(args): x1 = norm.ppf(ymode + direction * delta) x2 = norm.ppf(ymode + direction * 2 * delta) - y1 = TransformFunction.trans_triangular(x1, [_min, _mode, _max]) - y2 = TransformFunction.trans_triangular(x2, [_min, _mode, _max]) + y1 = TransformFunction.trans_triangular(x1, [min_, mode, max_]) + y2 = TransformFunction.trans_triangular(x2, [min_, mode, max_]) # Assert monotonicity if direction == -1: diff --git a/tests/ert/unit_tests/conftest.py b/tests/ert/unit_tests/conftest.py index accbc117075..a78d0072791 100644 --- a/tests/ert/unit_tests/conftest.py +++ b/tests/ert/unit_tests/conftest.py @@ -1,6 +1,5 @@ import os import sys -from typing import List, Optional import pytest @@ -52,8 +51,8 @@ def run_args(run_paths): def func( ert_config: ErtConfig, ensemble: Ensemble, - active_realizations: Optional[int] = None, - ) -> List[RunArg]: + active_realizations: int | None = None, + ) -> list[RunArg]: active_realizations = ( ert_config.model_config.num_realizations if active_realizations is None diff --git a/tests/ert/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py b/tests/ert/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py index d0426f8e118..3088ed0a131 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py +++ b/tests/ert/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py @@ -34,8 +34,8 @@ async def _run_server(): class TestEnsemble(Ensemble): __test__ = False - def __init__(self, _iter, reals, fm_steps, id_): - self.iter = _iter + def __init__(self, iter_, reals, fm_steps, id_): + self.iter = iter_ self.test_reals = reals self.fm_steps = fm_steps diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_scheduler.py b/tests/ert/unit_tests/ensemble_evaluator/test_scheduler.py index 8fbfb0dd47f..0b981b1cbce 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_scheduler.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_scheduler.py @@ -5,11 +5,7 @@ import pytest -from _ert.events import ( - EESnapshot, - EESnapshotUpdate, - ForwardModelStepChecksum, -) +from _ert.events import EESnapshot, EESnapshotUpdate, ForwardModelStepChecksum from ert.ensemble_evaluator import EnsembleEvaluator, Monitor, identifiers, state from ert.ensemble_evaluator.config import EvaluatorServerConfig diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_snapshot.py b/tests/ert/unit_tests/ensemble_evaluator/test_snapshot.py index f0492b1bb92..f3122b53a3f 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_snapshot.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_snapshot.py @@ -7,10 +7,7 @@ RealizationSuccess, ) from ert.ensemble_evaluator import state -from ert.ensemble_evaluator.snapshot import ( - EnsembleSnapshot, - FMStepSnapshot, -) +from ert.ensemble_evaluator.snapshot import EnsembleSnapshot, FMStepSnapshot from tests.ert import SnapshotBuilder diff --git a/tests/ert/unit_tests/forward_model_runner/test_file_reporter.py b/tests/ert/unit_tests/forward_model_runner/test_file_reporter.py index d64af04d83c..a4db3e1fb25 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_file_reporter.py +++ b/tests/ert/unit_tests/forward_model_runner/test_file_reporter.py @@ -30,9 +30,9 @@ def test_report_with_init_message_argument(reporter): r.report(Init([fmstep1], 1, 19)) - with open(STATUS_file, "r", encoding="utf-8") as f: + with open(STATUS_file, encoding="utf-8") as f: assert "Current host" in f.readline(), "STATUS file missing expected value" - with open(STATUS_json, "r", encoding="utf-8") as f: + with open(STATUS_json, encoding="utf-8") as f: content = "".join(f.readlines()) assert '"name": "fmstep1"' in content, "status.json missing fmstep1" assert '"status": "Waiting"' in content, "status.json missing Waiting status" @@ -56,14 +56,14 @@ def test_report_with_successful_start_message_argument(reporter): reporter.report(msg) - with open(STATUS_file, "r", encoding="utf-8") as f: + with open(STATUS_file, encoding="utf-8") as f: assert "fmstep1" in f.readline(), "STATUS file missing fmstep1" - with open(LOG_file, "r", encoding="utf-8") as f: + with open(LOG_file, encoding="utf-8") as f: assert ( "Calling: /bin/sh --foo 1 --bar 2" in f.readline() ), """JOB_LOG file missing executable and arguments""" - with open(STATUS_json, "r", encoding="utf-8") as f: + with open(STATUS_json, encoding="utf-8") as f: content = "".join(f.readlines()) assert '"status": "Running"' in content, "status.json missing Running status" assert '"start_time": null' not in content, "start_time not set" @@ -76,11 +76,11 @@ def test_report_with_failed_start_message_argument(reporter): reporter.report(msg) - with open(STATUS_file, "r", encoding="utf-8") as f: + with open(STATUS_file, encoding="utf-8") as f: assert ( "EXIT: -10/massive_failure" in f.readline() ), "STATUS file missing EXIT message" - with open(STATUS_json, "r", encoding="utf-8") as f: + with open(STATUS_json, encoding="utf-8") as f: content = "".join(f.readlines()) assert '"status": "Failure"' in content, "status.json missing Failure status" assert ( @@ -98,7 +98,7 @@ def test_report_with_successful_exit_message_argument(reporter): reporter.report(msg) - with open(STATUS_json, "r", encoding="utf-8") as f: + with open(STATUS_json, encoding="utf-8") as f: content = "".join(f.readlines()) assert '"status": "Success"' in content, "status.json missing Success status" @@ -112,9 +112,9 @@ def test_report_with_failed_exit_message_argument(reporter): reporter.report(msg) - with open(STATUS_file, "r", encoding="utf-8") as f: + with open(STATUS_file, encoding="utf-8") as f: assert "EXIT: 1/massive_failure" in f.readline() - with open(ERROR_file, "r", encoding="utf-8") as f: + with open(ERROR_file, encoding="utf-8") as f: content = "".join(f.readlines()) assert "fmstep1" in content, "ERROR file missing fmstep" assert ( @@ -123,7 +123,7 @@ def test_report_with_failed_exit_message_argument(reporter): assert ( "stderr: Not redirected" in content ), "ERROR had invalid stderr information" - with open(STATUS_json, "r", encoding="utf-8") as f: + with open(STATUS_json, encoding="utf-8") as f: content = "".join(f.readlines()) assert '"status": "Failure"' in content, "status.json missing Failure status" assert ( @@ -142,7 +142,7 @@ def test_report_with_running_message_argument(reporter): reporter.report(msg) - with open(STATUS_json, "r", encoding="utf-8") as f: + with open(STATUS_json, encoding="utf-8") as f: content = "".join(f.readlines()) assert '"status": "Running"' in content, "status.json missing status" assert ( @@ -178,7 +178,7 @@ def test_dump_error_file_with_stderr(reporter): "massive_failure", ) - with open(ERROR_file, "r", encoding="utf-8") as f: + with open(ERROR_file, encoding="utf-8") as f: content = "".join(f.readlines()) assert "E_MASSIVE_FAILURE" in content, "ERROR file missing stderr content" assert "" in content, "ERROR missing stderr_file part" @@ -226,7 +226,7 @@ def test_status_file_is_correct(reporter): f"EXIT: {exited_j_2.exit_code}/{exited_j_2.error_message}\n" ) - with open(STATUS_file, "r", encoding="utf-8") as f: + with open(STATUS_file, encoding="utf-8") as f: for expected in [ "Current host", expected_j1_line, diff --git a/tests/ert/unit_tests/forward_model_runner/test_forward_model_runner.py b/tests/ert/unit_tests/forward_model_runner/test_forward_model_runner.py index dc4e9103b40..4fcdcca033b 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_forward_model_runner.py +++ b/tests/ert/unit_tests/forward_model_runner/test_forward_model_runner.py @@ -245,7 +245,7 @@ def test_exec_env(): fptr, ) - with open("jobs.json", "r", encoding="utf-8") as f: + with open("jobs.json", encoding="utf-8") as f: jobs_json = json.load(f) for msg in list(ForwardModelRunner(jobs_json).run([])): @@ -286,7 +286,7 @@ def test_env_var_available_inside_step_context(): fptr, ) - with open("jobs.json", "r", encoding="utf-8") as f: + with open("jobs.json", encoding="utf-8") as f: jobs_json = json.load(f) # Check ENV variable not available outside of step context @@ -336,7 +336,7 @@ def test_default_env_variables_available_inside_fm_step_context(): fptr, ) - with open("jobs.json", "r", encoding="utf-8") as f: + with open("jobs.json", encoding="utf-8") as f: jobs_json = json.load(f) # Check default ENV variable not available outside of step context diff --git a/tests/ert/unit_tests/forward_model_runner/test_forward_model_step.py b/tests/ert/unit_tests/forward_model_runner/test_forward_model_step.py index bd475e8734e..6b93c2ecf02 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_forward_model_step.py +++ b/tests/ert/unit_tests/forward_model_runner/test_forward_model_step.py @@ -141,7 +141,7 @@ def max_memory_per_subprocess_layer(layers: int) -> int: # comparing the memory used with different amounts of forks done. # subtract a little bit (* 0.9) due to natural variance in memory used # when running the program. - memory_per_numbers_list = sys.getsizeof(int(0)) * blobsize * 0.90 + memory_per_numbers_list = sys.getsizeof(0) * blobsize * 0.90 max_seens = [max_memory_per_subprocess_layer(layers) for layers in range(3)] assert max_seens[0] + memory_per_numbers_list < max_seens[1] diff --git a/tests/ert/unit_tests/gui/simulation/view/test_legend.py b/tests/ert/unit_tests/gui/simulation/view/test_legend.py index 6c6f6d0771d..40730621c2f 100644 --- a/tests/ert/unit_tests/gui/simulation/view/test_legend.py +++ b/tests/ert/unit_tests/gui/simulation/view/test_legend.py @@ -1,5 +1,3 @@ -from typing import Dict - import hypothesis.strategies as st import pytest from hypothesis import HealthCheck, given, settings @@ -18,7 +16,7 @@ ) ) @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) -def test_marker_label_text_correct(qtbot, status: Dict[str, int]): +def test_marker_label_text_correct(qtbot, status: dict[str, int]): realization_count = sum(status.values()) progress_widget = ProgressWidget() qtbot.addWidget(progress_widget) @@ -44,7 +42,7 @@ def test_marker_label_text_correct(qtbot, status: Dict[str, int]): ) ) @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) -def test_progress_state_width_correct(qtbot, status: Dict[str, int]): +def test_progress_state_width_correct(qtbot, status: dict[str, int]): realization_count = sum(status.values()) progress_widget = ProgressWidget() qtbot.addWidget(progress_widget) diff --git a/tests/ert/unit_tests/gui/tools/plot/test_plot_window.py b/tests/ert/unit_tests/gui/tools/plot/test_plot_window.py index 4191704c896..0d23a9d4dcb 100644 --- a/tests/ert/unit_tests/gui/tools/plot/test_plot_window.py +++ b/tests/ert/unit_tests/gui/tools/plot/test_plot_window.py @@ -1,9 +1,6 @@ from pytestqt.qtbot import QtBot from qtpy.QtCore import Qt -from qtpy.QtWidgets import ( - QApplication, - QPushButton, -) +from qtpy.QtWidgets import QApplication, QPushButton from ert.gui.tools.plot.plot_window import create_error_dialog diff --git a/tests/ert/unit_tests/plugins/test_plugin_manager.py b/tests/ert/unit_tests/plugins/test_plugin_manager.py index 013b61af49c..35c4185e224 100644 --- a/tests/ert/unit_tests/plugins/test_plugin_manager.py +++ b/tests/ert/unit_tests/plugins/test_plugin_manager.py @@ -11,9 +11,7 @@ from ert.config import ErtConfig from ert.plugins import ErtPluginManager, plugin from tests.ert.unit_tests.plugins import dummy_plugins -from tests.ert.unit_tests.plugins.dummy_plugins import ( - DummyFMStep, -) +from tests.ert.unit_tests.plugins.dummy_plugins import DummyFMStep def test_no_plugins(): diff --git a/tests/ert/unit_tests/resources/test_shell.py b/tests/ert/unit_tests/resources/test_shell.py index 74a9cb7989e..993da23ccb7 100644 --- a/tests/ert/unit_tests/resources/test_shell.py +++ b/tests/ert/unit_tests/resources/test_shell.py @@ -217,12 +217,12 @@ def test_move_file_into_folder_file_exists(): with open("file", "w", encoding="utf-8") as f: f.write("new") - with open("dst_folder/file", "r", encoding="utf-8") as f: + with open("dst_folder/file", encoding="utf-8") as f: content = f.read() assert content == "old" move_file("file", "dst_folder") - with open("dst_folder/file", "r", encoding="utf-8") as f: + with open("dst_folder/file", encoding="utf-8") as f: content = f.read() assert content == "new" @@ -237,7 +237,7 @@ def test_move_pathfile_into_folder(): f.write("stuff") move_file("source1/source2/file", "dst_folder") - with open("dst_folder/file", "r", encoding="utf-8") as f: + with open("dst_folder/file", encoding="utf-8") as f: content = f.read() assert content == "stuff" @@ -255,7 +255,7 @@ def test_move_pathfile_into_folder_file_exists(): f.write("garbage") move_file("source1/source2/file", "dst_folder") - with open("dst_folder/file", "r", encoding="utf-8") as f: + with open("dst_folder/file", encoding="utf-8") as f: content = f.read() assert content == "stuff" @@ -496,7 +496,7 @@ def test_careful_copy_file(): f.write("hallo") careful_copy_file("file1", "file2") - with open("file2", "r", encoding="utf-8") as f: + with open("file2", encoding="utf-8") as f: assert f.readline() == "hallo" print(careful_copy_file("file1", "file3")) diff --git a/tests/ert/unit_tests/resources/test_templating.py b/tests/ert/unit_tests/resources/test_templating.py index 47a77c64cb1..c65868dd195 100644 --- a/tests/ert/unit_tests/resources/test_templating.py +++ b/tests/ert/unit_tests/resources/test_templating.py @@ -153,7 +153,7 @@ def test_template_multiple_input(): render_template(["second.json", "third.json"], "template", "out_file") - with open("out_file", "r", encoding="utf-8") as parameter_file: + with open("out_file", encoding="utf-8") as parameter_file: expected_output = ( "FILENAME\n" + "F1 1999.22\n" + "OTH 1400\n" + "OTH_TEST 3000.22" ) @@ -186,7 +186,7 @@ def test_no_parameters_json(): "out_file", ) - with open("out_file", "r", encoding="utf-8") as parameter_file: + with open("out_file", encoding="utf-8") as parameter_file: expected_output = ( "FILENAME\n" + "F1 1999.22\n" + "OTH 1400\n" + "OTH_TEST 3000.22" ) @@ -226,7 +226,7 @@ def test_template_executable(): subprocess.call(template_render_exec + params, shell=True, stdout=subprocess.PIPE) - with open("out_file", "r", encoding="utf-8") as parameter_file: + with open("out_file", encoding="utf-8") as parameter_file: expected_output = "FILENAME\n" + "F1 1999.22\n" + "F2 200" assert parameter_file.read() == expected_output diff --git a/tests/ert/unit_tests/scheduler/bin/bhist.py b/tests/ert/unit_tests/scheduler/bin/bhist.py index 57d8358e247..9016be99fde 100644 --- a/tests/ert/unit_tests/scheduler/bin/bhist.py +++ b/tests/ert/unit_tests/scheduler/bin/bhist.py @@ -3,7 +3,6 @@ import time from pathlib import Path from textwrap import dedent -from typing import List, Optional from pydantic import BaseModel @@ -31,7 +30,7 @@ def get_parser() -> argparse.ArgumentParser: return parser -def bhist_formatter(jobstats: List[Job]) -> str: +def bhist_formatter(jobstats: list[Job]) -> str: string = "Summary of time in seconds spent in various states:\n" string += "JOBID USER JOB_NAME PEND PSUSP RUN USUSP SSUSP UNKWN TOTAL\n" for job in jobstats: @@ -45,7 +44,7 @@ def bhist_formatter(jobstats: List[Job]) -> str: return string -def bhist_long_formatter(jobstats: List[Job]) -> str: +def bhist_long_formatter(jobstats: list[Job]) -> str: """ This function outputs stub data entirely independent from the input. """ @@ -73,7 +72,7 @@ def bhist_long_formatter(jobstats: List[Job]) -> str: return f"{50*'-'}".join(formatted_job_outputs) -def read(path: Path, default: Optional[str] = None) -> Optional[str]: +def read(path: Path, default: str | None = None) -> str | None: return path.read_text().strip() if path.exists() else default @@ -82,7 +81,7 @@ def main() -> None: jobs_path = Path(os.getenv("PYTEST_TMP_PATH", ".")) / "mock_jobs" - jobs_output: List[Job] = [] + jobs_output: list[Job] = [] for job in args.jobs: job_name: str = read(jobs_path / f"{job}.name") or "_" assert job_name is not None diff --git a/tests/ert/unit_tests/scheduler/bin/bjobs.py b/tests/ert/unit_tests/scheduler/bin/bjobs.py index 7d7a636df96..e7c426adc57 100644 --- a/tests/ert/unit_tests/scheduler/bin/bjobs.py +++ b/tests/ert/unit_tests/scheduler/bin/bjobs.py @@ -1,7 +1,7 @@ import argparse import os from pathlib import Path -from typing import List, Literal, Optional +from typing import Literal from pydantic import BaseModel @@ -26,11 +26,11 @@ def get_parser() -> argparse.ArgumentParser: return parser -def bjobs_formatter(jobstats: List[Job]) -> str: +def bjobs_formatter(jobstats: list[Job]) -> str: return "".join([f"{job.job_id}^{job.job_state}^-\n" for job in jobstats]) -def read(path: Path, default: Optional[str] = None) -> Optional[str]: +def read(path: Path, default: str | None = None) -> str | None: return path.read_text().strip() if path.exists() else default @@ -45,7 +45,7 @@ def main() -> None: print(returncode) return - jobs_output: List[Job] = [] + jobs_output: list[Job] = [] for job in args.jobs: pid = read(jobs_path / f"{job}.pid") returncode = read(jobs_path / f"{job}.returncode") diff --git a/tests/ert/unit_tests/scheduler/bin/qstat.py b/tests/ert/unit_tests/scheduler/bin/qstat.py index c14b0689b1c..c2b1e1ae0cc 100644 --- a/tests/ert/unit_tests/scheduler/bin/qstat.py +++ b/tests/ert/unit_tests/scheduler/bin/qstat.py @@ -5,7 +5,7 @@ import random from argparse import ArgumentParser, Namespace from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any QSTAT_HEADER = ( "Job id Name User Time Use S Queue\n" @@ -13,7 +13,7 @@ ) -def read(path: Path, default: Optional[str] = None) -> Optional[str]: +def read(path: Path, default: str | None = None) -> str | None: return path.read_text().strip() if path.exists() else default @@ -54,7 +54,7 @@ def main() -> None: elif pid is not None: state = "R" - info: Dict[str, Any] = { + info: dict[str, Any] = { "Job_Name": name, "job_state": state, } diff --git a/tests/ert/unit_tests/scheduler/bin/sacct.py b/tests/ert/unit_tests/scheduler/bin/sacct.py index 36ce573eeff..d2afc1e6819 100644 --- a/tests/ert/unit_tests/scheduler/bin/sacct.py +++ b/tests/ert/unit_tests/scheduler/bin/sacct.py @@ -9,7 +9,7 @@ import glob import os from pathlib import Path -from typing import Literal, Optional +from typing import Literal JobState = Literal["PENDING", "RUNNING", "COMPLETED", "FAILED", "CANCELLED"] @@ -24,7 +24,7 @@ def get_parser() -> argparse.ArgumentParser: return parser -def read(path: Path, default: Optional[str] = None) -> Optional[str]: +def read(path: Path, default: str | None = None) -> str | None: return path.read_text().strip() if path.exists() else default diff --git a/tests/ert/unit_tests/scheduler/bin/scontrol.py b/tests/ert/unit_tests/scheduler/bin/scontrol.py index cce4b99ae9e..9f9d7dfc568 100644 --- a/tests/ert/unit_tests/scheduler/bin/scontrol.py +++ b/tests/ert/unit_tests/scheduler/bin/scontrol.py @@ -9,7 +9,7 @@ import glob import os from pathlib import Path -from typing import Literal, Optional +from typing import Literal JobState = Literal["PENDING", "RUNNING", "COMPLETED", "FAILED", "CANCELLED"] @@ -22,7 +22,7 @@ def get_parser() -> argparse.ArgumentParser: return parser -def read(path: Path, default: Optional[str] = None) -> Optional[str]: +def read(path: Path, default: str | None = None) -> str | None: return path.read_text().strip() if path.exists() else default diff --git a/tests/ert/unit_tests/scheduler/bin/squeue.py b/tests/ert/unit_tests/scheduler/bin/squeue.py index 4ac8ff9b878..8b09369cf37 100644 --- a/tests/ert/unit_tests/scheduler/bin/squeue.py +++ b/tests/ert/unit_tests/scheduler/bin/squeue.py @@ -9,7 +9,7 @@ import glob import os from pathlib import Path -from typing import Literal, Optional +from typing import Literal JobState = Literal["PENDING", "RUNNING", "COMPLETED", "FAILED", "CANCELLED"] @@ -27,7 +27,7 @@ def get_parser() -> argparse.ArgumentParser: return parser -def read(path: Path, default: Optional[str] = None) -> Optional[str]: +def read(path: Path, default: str | None = None) -> str | None: return path.read_text().strip() if path.exists() else default diff --git a/tests/ert/unit_tests/scheduler/conftest.py b/tests/ert/unit_tests/scheduler/conftest.py index dc8b7d41a5e..e6ce3ebef58 100644 --- a/tests/ert/unit_tests/scheduler/conftest.py +++ b/tests/ert/unit_tests/scheduler/conftest.py @@ -3,8 +3,9 @@ import asyncio import os import sys +from collections.abc import Coroutine from pathlib import Path -from typing import Any, Coroutine, Literal +from typing import Any, Literal import pytest diff --git a/tests/ert/unit_tests/scheduler/test_job.py b/tests/ert/unit_tests/scheduler/test_job.py index ef3e0307394..318f03dab4d 100644 --- a/tests/ert/unit_tests/scheduler/test_job.py +++ b/tests/ert/unit_tests/scheduler/test_job.py @@ -3,7 +3,6 @@ import shutil from functools import partial from pathlib import Path -from typing import List from unittest.mock import AsyncMock, MagicMock import pytest @@ -56,7 +55,7 @@ def realization(): async def assert_scheduler_events( - scheduler: Scheduler, expected_job_events: List[JobState] + scheduler: Scheduler, expected_job_events: list[JobState] ) -> None: for expected_job_event in expected_job_events: assert ( diff --git a/tests/ert/unit_tests/scheduler/test_lsf_driver.py b/tests/ert/unit_tests/scheduler/test_lsf_driver.py index 84ae5ce39ac..088575ad618 100644 --- a/tests/ert/unit_tests/scheduler/test_lsf_driver.py +++ b/tests/ert/unit_tests/scheduler/test_lsf_driver.py @@ -7,10 +7,11 @@ import stat import string import time +from collections.abc import Collection from contextlib import ExitStack as does_not_raise from pathlib import Path from textwrap import dedent -from typing import Collection, List, Optional, get_args, get_type_hints +from typing import get_args, get_type_hints from unittest.mock import AsyncMock import pytest @@ -99,7 +100,7 @@ async def test_exit_codes(tmp_path_factory, bjobs_script, bhist_script, exit_cod exit_code=st.integers(min_value=1, max_value=254), ) async def test_events_produced_from_jobstate_updates( - tmp_path_factory, jobstate_sequence: List[str], exit_code: int + tmp_path_factory, jobstate_sequence: list[str], exit_code: int ): tmp_path = tmp_path_factory.mktemp("bjobs_mock") mocked_bjobs = tmp_path / "bjobs" @@ -293,9 +294,9 @@ async def test_faulty_bsub_produces_error_log(monkeypatch, tmp_path): bin_path.mkdir() monkeypatch.setenv("PATH", f"{bin_path}:{os.environ['PATH']}") - _out = "THIS_IS_OUTPUT" - _err = "THIS_IS_ERROR" - bsub_script = f"echo {_out} && echo {_err} >&2; exit 1" + out = "THIS_IS_OUTPUT" + err = "THIS_IS_ERROR" + bsub_script = f"echo {out} && echo {err} >&2; exit 1" bsub_path = bin_path / "bsub" bsub_path.write_text(f"#!/bin/sh\n{bsub_script}") bsub_path.chmod(bsub_path.stat().st_mode | stat.S_IEXEC) @@ -304,7 +305,7 @@ async def test_faulty_bsub_produces_error_log(monkeypatch, tmp_path): with pytest.raises(RuntimeError): await driver.submit(0, "sleep") assert ( - f'failed with exit code 1, output: "{_out}", and error: "{_err}"' + f'failed with exit code 1, output: "{out}", and error: "{err}"' in driver._job_error_message_by_iens[0] ) @@ -782,9 +783,9 @@ async def test_that_bsub_will_retry_and_succeed( ], ) def test_build_resource_requirement_string( - resource_requirement: Optional[str], - exclude_hosts: List[str], - realization_memory: Optional[int], + resource_requirement: str | None, + exclude_hosts: list[str], + realization_memory: int | None, expected_string: str, ): assert ( @@ -1070,9 +1071,9 @@ async def test_lsf_can_retrieve_stdout_and_stderr( os.chdir(tmp_path) driver = LsfDriver() num_written_characters = 600 - _out = generate_random_text(num_written_characters) - _err = generate_random_text(num_written_characters) - await driver.submit(0, "sh", "-c", f"echo {_out} && echo {_err} >&2", name=job_name) + out = generate_random_text(num_written_characters) + err = generate_random_text(num_written_characters) + await driver.submit(0, "sh", "-c", f"echo {out} && echo {err} >&2", name=job_name) await poll(driver, {0}) message = driver.read_stdout_and_stderr_files( runpath=".", @@ -1091,9 +1092,9 @@ async def test_lsf_cannot_retrieve_stdout_and_stderr(tmp_path, job_name): os.chdir(tmp_path) driver = LsfDriver() num_written_characters = 600 - _out = generate_random_text(num_written_characters) - _err = generate_random_text(num_written_characters) - await driver.submit(0, "sh", "-c", f"echo {_out} && echo {_err} >&2", name=job_name) + out = generate_random_text(num_written_characters) + err = generate_random_text(num_written_characters) + await driver.submit(0, "sh", "-c", f"echo {out} && echo {err} >&2", name=job_name) await poll(driver, {0}) # let's remove the output files os.remove(job_name + ".LSF-stderr") diff --git a/tests/ert/unit_tests/scheduler/test_openpbs_driver.py b/tests/ert/unit_tests/scheduler/test_openpbs_driver.py index 672974287c7..41f99b61c2f 100644 --- a/tests/ert/unit_tests/scheduler/test_openpbs_driver.py +++ b/tests/ert/unit_tests/scheduler/test_openpbs_driver.py @@ -8,7 +8,6 @@ from functools import partial from pathlib import Path from textwrap import dedent -from typing import Dict, List import pytest from hypothesis import given @@ -48,7 +47,7 @@ @given(st.lists(st.sampled_from(JOB_STATES))) -async def test_events_produced_from_jobstate_updates(jobstate_sequence: List[str]): +async def test_events_produced_from_jobstate_updates(jobstate_sequence: list[str]): # Determine what to expect from the sequence: started = False finished = False @@ -121,7 +120,7 @@ def capturing_qsub(monkeypatch, tmp_path): qsub_path.chmod(qsub_path.stat().st_mode | stat.S_IEXEC) -def parse_resource_string(qsub_args: str) -> Dict[str, str]: +def parse_resource_string(qsub_args: str) -> dict[str, str]: resources = {} args = shlex.split(qsub_args) diff --git a/tests/ert/unit_tests/scheduler/test_scheduler.py b/tests/ert/unit_tests/scheduler/test_scheduler.py index 3d89b924fc0..e96074f493b 100644 --- a/tests/ert/unit_tests/scheduler/test_scheduler.py +++ b/tests/ert/unit_tests/scheduler/test_scheduler.py @@ -6,7 +6,6 @@ import time from functools import partial from pathlib import Path -from typing import List import pytest @@ -251,7 +250,7 @@ async def wait(): @pytest.mark.flaky(reruns=3) @pytest.mark.parametrize("max_running", [0, 1, 2, 10]) async def test_max_running(max_running, mock_driver, storage, tmp_path): - runs: List[bool] = [] + runs: list[bool] = [] async def wait(): nonlocal runs @@ -496,7 +495,7 @@ async def test_submit_sleep( tmp_path, mock_driver, ): - run_start_times: List[float] = [] + run_start_times: list[float] = [] async def wait(): nonlocal run_start_times @@ -541,7 +540,7 @@ async def wait(): async def test_submit_sleep_with_max_running( submit_sleep, realization_max_runtime, max_running, storage, tmp_path, mock_driver ): - run_start_times: List[float] = [] + run_start_times: list[float] = [] async def wait(): nonlocal run_start_times diff --git a/tests/ert/unit_tests/scheduler/test_slurm_driver.py b/tests/ert/unit_tests/scheduler/test_slurm_driver.py index e9fb263b311..71da765a5b3 100644 --- a/tests/ert/unit_tests/scheduler/test_slurm_driver.py +++ b/tests/ert/unit_tests/scheduler/test_slurm_driver.py @@ -267,9 +267,9 @@ async def test_faulty_sbatch_produces_error_log(monkeypatch, tmp_path): bin_path.mkdir() monkeypatch.setenv("PATH", f"{bin_path}:{os.environ['PATH']}") - _out = "THIS_IS_OUTPUT" - _err = "THIS_IS_ERROR" - sbatch_script = f"echo {_out} && echo {_err} >&2; exit 1" + out = "THIS_IS_OUTPUT" + err = "THIS_IS_ERROR" + sbatch_script = f"echo {out} && echo {err} >&2; exit 1" sbatch_path = bin_path / "sbatch" sbatch_path.write_text(f"#!/bin/sh\n{sbatch_script}") sbatch_path.chmod(sbatch_path.stat().st_mode | stat.S_IEXEC) @@ -278,7 +278,7 @@ async def test_faulty_sbatch_produces_error_log(monkeypatch, tmp_path): with pytest.raises(RuntimeError): await driver.submit(0, "sleep") assert ( - f'failed with exit code 1, output: "{_out}", and error: "{_err}"' + f'failed with exit code 1, output: "{out}", and error: "{err}"' in driver._job_error_message_by_iens[0] ) @@ -335,9 +335,9 @@ async def test_slurm_can_retrieve_stdout_and_stderr( os.chdir(tmp_path) driver = SlurmDriver() num_written_characters = 600 - _out = generate_random_text(num_written_characters) - _err = generate_random_text(num_written_characters) - await driver.submit(0, "sh", "-c", f"echo {_out} && echo {_err} >&2", name=job_name) + out = generate_random_text(num_written_characters) + err = generate_random_text(num_written_characters) + await driver.submit(0, "sh", "-c", f"echo {out} && echo {err} >&2", name=job_name) await poll(driver, {0}) message = driver.read_stdout_and_stderr_files( runpath=".", diff --git a/tests/ert/unit_tests/storage/create_runpath.py b/tests/ert/unit_tests/storage/create_runpath.py index a22a163aa92..5d6ac1d56bb 100644 --- a/tests/ert/unit_tests/storage/create_runpath.py +++ b/tests/ert/unit_tests/storage/create_runpath.py @@ -1,5 +1,3 @@ -from typing import Optional, Tuple - from ert.config import ErtConfig from ert.enkf_main import create_run_path, sample_prior from ert.run_arg import create_run_arguments @@ -12,10 +10,10 @@ def create_runpath( config, active_mask=None, *, - ensemble: Optional[Ensemble] = None, + ensemble: Ensemble | None = None, iteration=0, - random_seed: Optional[int] = 1234, -) -> Tuple[ErtConfig, Ensemble]: + random_seed: int | None = 1234, +) -> tuple[ErtConfig, Ensemble]: active_mask = [True] if active_mask is None else active_mask ert_config = ErtConfig.from_file(config) diff --git a/tests/ert/unit_tests/storage/test_local_storage.py b/tests/ert/unit_tests/storage/test_local_storage.py index 3bcf6a860ac..618879acc9b 100644 --- a/tests/ert/unit_tests/storage/test_local_storage.py +++ b/tests/ert/unit_tests/storage/test_local_storage.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from datetime import datetime, timedelta from pathlib import Path -from typing import Any, Dict, List +from typing import Any from unittest.mock import MagicMock, PropertyMock, patch from uuid import UUID @@ -56,7 +56,7 @@ def test_create_experiment(tmp_path): assert (experiment_path / experiment._parameter_file).exists() assert (experiment_path / experiment._responses_file).exists() - with open(experiment_path / "index.json", encoding="utf-8", mode="r") as f: + with open(experiment_path / "index.json", encoding="utf-8") as f: index = json.load(f) assert index["id"] == str(experiment.id) assert index["name"] == "test-experiment" @@ -560,7 +560,7 @@ def _inner(params): @st.composite -def fields(draw, egrid, num_fields=small_ints) -> List[Field]: +def fields(draw, egrid, num_fields=small_ints) -> list[Field]: grid_file, grid = egrid nx, ny, nz = grid.shape return [ @@ -640,18 +640,18 @@ def test_write_transaction_overwrites(tmp_path): @dataclass class Ensemble: uuid: UUID - parameter_values: Dict[str, Any] = field(default_factory=dict) - response_values: Dict[str, Any] = field(default_factory=dict) - failure_messages: Dict[int, str] = field(default_factory=dict) + parameter_values: dict[str, Any] = field(default_factory=dict) + response_values: dict[str, Any] = field(default_factory=dict) + failure_messages: dict[int, str] = field(default_factory=dict) @dataclass class Experiment: uuid: UUID - ensembles: Dict[UUID, Ensemble] = field(default_factory=dict) - parameters: List[ParameterConfig] = field(default_factory=list) - responses: List[ResponseConfig] = field(default_factory=list) - observations: Dict[str, polars.DataFrame] = field(default_factory=dict) + ensembles: dict[UUID, Ensemble] = field(default_factory=dict) + parameters: list[ParameterConfig] = field(default_factory=list) + responses: list[ResponseConfig] = field(default_factory=list) + observations: dict[str, polars.DataFrame] = field(default_factory=dict) @settings(max_examples=250) @@ -687,7 +687,7 @@ def __init__(self): self.tmpdir = tempfile.mkdtemp(prefix="StatefulStorageTest") self.storage = open_storage(self.tmpdir + "/storage/", "w") note(f"storage path is: {self.storage.path}") - self.model: Dict[UUID, Experiment] = {} + self.model: dict[UUID, Experiment] = {} assert list(self.storage.ensembles) == [] # Realization to save/delete params/responses @@ -741,8 +741,8 @@ def reopen(self): ) def create_experiment( self, - parameters: List[ParameterConfig], - responses: List[ResponseConfig], + parameters: list[ParameterConfig], + responses: list[ResponseConfig], obs: EnkfObs, ): experiment_id = self.storage.create_experiment( diff --git a/tests/ert/unit_tests/storage/test_mode.py b/tests/ert/unit_tests/storage/test_mode.py index a19cc9022f2..297196c6843 100644 --- a/tests/ert/unit_tests/storage/test_mode.py +++ b/tests/ert/unit_tests/storage/test_mode.py @@ -1,11 +1,6 @@ import pytest -from ert.storage.mode import ( - BaseMode, - Mode, - ModeError, - require_write, -) +from ert.storage.mode import BaseMode, Mode, ModeError, require_write class SomeClass(BaseMode): diff --git a/tests/ert/unit_tests/test_run_path_creation.py b/tests/ert/unit_tests/test_run_path_creation.py index e4614baa626..cce5202af5f 100644 --- a/tests/ert/unit_tests/test_run_path_creation.py +++ b/tests/ert/unit_tests/test_run_path_creation.py @@ -484,7 +484,7 @@ def test_write_snakeoil_runpath_file(snake_oil_case, storage, itr): run_paths = Runpaths( jobname_format=jobname_fmt, runpath_format=runpath_fmt, - filename=str("a_file_name"), + filename="a_file_name", substitutions=global_substitutions, ) sample_prior(prior_ensemble, [i for i, active in enumerate(mask) if active]) @@ -523,7 +523,7 @@ def test_write_snakeoil_runpath_file(snake_oil_case, storage, itr): ] exp_runpaths = list(map(os.path.realpath, exp_runpaths)) - with open(runpath_list_path, "r", encoding="utf-8") as f: + with open(runpath_list_path, encoding="utf-8") as f: dumped_runpaths = list( zip(*[line.split() for line in f.readlines()], strict=False) )[1] diff --git a/tests/ert/unit_tests/test_tracking.py b/tests/ert/unit_tests/test_tracking.py index cf031d8f1f6..ab9dd76d41b 100644 --- a/tests/ert/unit_tests/test_tracking.py +++ b/tests/ert/unit_tests/test_tracking.py @@ -4,7 +4,6 @@ import re from argparse import ArgumentParser from pathlib import Path -from typing import Dict import pytest from jsonpath_ng import parse @@ -199,7 +198,7 @@ def test_tracking( ) thread.start() - snapshots: Dict[str, EnsembleSnapshot] = {} + snapshots: dict[str, EnsembleSnapshot] = {} thread.join() @@ -300,7 +299,7 @@ def test_setting_env_context_during_run( expected = ["_ERT_SIMULATION_MODE", "_ERT_EXPERIMENT_ID", "_ERT_ENSEMBLE_ID"] for event, environment in zip(queue.events, queue.environment, strict=False): - if isinstance(event, (FullSnapshotEvent, SnapshotUpdateEvent)): + if isinstance(event, FullSnapshotEvent | SnapshotUpdateEvent): for key in expected: assert key in environment assert environment.get("_ERT_SIMULATION_MODE") == mode @@ -379,7 +378,7 @@ def test_run_information_present_as_env_var_in_fm_context( # Check run information in job environment for path in model.paths: - with open(Path(path) / "jobs.json", "r", encoding="utf-8") as f: + with open(Path(path) / "jobs.json", encoding="utf-8") as f: jobs_data = json.load(f) for key in expected: assert key in jobs_data["global_environment"] diff --git a/tests/ert/unit_tests/workflow_runner/test_workflow.py b/tests/ert/unit_tests/workflow_runner/test_workflow.py index 547ca2f2eaf..c6daa4b550b 100644 --- a/tests/ert/unit_tests/workflow_runner/test_workflow.py +++ b/tests/ert/unit_tests/workflow_runner/test_workflow.py @@ -59,10 +59,10 @@ def test_workflow_run(): WorkflowRunner(workflow).run_blocking() - with open("dump1", "r", encoding="utf-8") as f: + with open("dump1", encoding="utf-8") as f: assert f.read() == "dump_text_1" - with open("dump2", "r", encoding="utf-8") as f: + with open("dump2", encoding="utf-8") as f: assert f.read() == "dump_text_2" diff --git a/tests/ert/unit_tests/workflow_runner/test_workflow_job.py b/tests/ert/unit_tests/workflow_runner/test_workflow_job.py index db10d584687..958cbe5a583 100644 --- a/tests/ert/unit_tests/workflow_runner/test_workflow_job.py +++ b/tests/ert/unit_tests/workflow_runner/test_workflow_job.py @@ -58,7 +58,7 @@ def test_run_external_job(): assert runner.run(["test", "text"]) is None assert runner.stdoutdata() == "Hello World\n" - with open("test", "r", encoding="utf-8") as f: + with open("test", encoding="utf-8") as f: assert f.read() == "text" diff --git a/tests/everest/conftest.py b/tests/everest/conftest.py index 0974652c7b4..9db0abf0aba 100644 --- a/tests/everest/conftest.py +++ b/tests/everest/conftest.py @@ -1,9 +1,9 @@ import os import shutil import tempfile +from collections.abc import Callable, Iterator from copy import deepcopy from pathlib import Path -from typing import Callable, Dict, Iterator, Optional, Union import pytest @@ -23,8 +23,8 @@ def testdata() -> Path: @pytest.fixture def copy_testdata_tmpdir( testdata: Path, tmp_path: Path -) -> Iterator[Callable[[Optional[str]], Path]]: - def _copy_tree(path: Optional[str] = None): +) -> Iterator[Callable[[str | None], Path]]: + def _copy_tree(path: str | None = None): path_ = testdata if path is None else testdata / path shutil.copytree(path_, tmp_path, dirs_exist_ok=True) return path_ @@ -36,7 +36,7 @@ def _copy_tree(path: Optional[str] = None): @pytest.fixture(scope="module") -def control_data_no_variables() -> Dict[str, Union[str, float]]: +def control_data_no_variables() -> dict[str, str | float]: return { "name": "group_0", "type": "well_control", @@ -83,7 +83,7 @@ def control_data_no_variables() -> Dict[str, Union[str, float]]: ) def control_config( request, - control_data_no_variables: Dict[str, Union[str, float]], + control_data_no_variables: dict[str, str | float], ) -> ControlConfig: config = deepcopy(control_data_no_variables) config["variables"] = request.param diff --git a/tests/everest/entry_points/test_config_branch_entry.py b/tests/everest/entry_points/test_config_branch_entry.py index 4207f73d247..9d3c69fb1dd 100644 --- a/tests/everest/entry_points/test_config_branch_entry.py +++ b/tests/everest/entry_points/test_config_branch_entry.py @@ -85,8 +85,8 @@ def test_config_branch_preserves_config_section_order( diff_lines = [] with ( - open(CONFIG_FILE, "r", encoding="utf-8") as initial_config, - open(new_config_file_name, "r", encoding="utf-8") as branch_config, + open(CONFIG_FILE, encoding="utf-8") as initial_config, + open(new_config_file_name, encoding="utf-8") as branch_config, ): diff = difflib.unified_diff( initial_config.readlines(), diff --git a/tests/everest/entry_points/test_everexport.py b/tests/everest/entry_points/test_everexport.py index 897b45e7ae5..23913f534d5 100644 --- a/tests/everest/entry_points/test_everexport.py +++ b/tests/everest/entry_points/test_everexport.py @@ -98,7 +98,7 @@ def test_everexport_entry_empty(mocked_func, copy_math_func_test_data_to_tmp): everexport_entry([CONFIG_FILE_MINIMAL]) assert os.path.isfile(export_file_path) - with open(export_file_path, "r", encoding="utf-8") as f: + with open(export_file_path, encoding="utf-8") as f: content = f.read() assert not content.strip() diff --git a/tests/everest/functional/test_main_everest_entry.py b/tests/everest/functional/test_main_everest_entry.py index bf28d46aa34..a185c1469bd 100644 --- a/tests/everest/functional/test_main_everest_entry.py +++ b/tests/everest/functional/test_main_everest_entry.py @@ -12,10 +12,7 @@ from everest import __version__ as everest_version from everest.bin.main import start_everest from everest.config import EverestConfig, ServerConfig -from everest.detached import ( - ServerStatus, - everserver_status, -) +from everest.detached import ServerStatus, everserver_status CONFIG_FILE_MINIMAL = "config_minimal.yml" WELL_ORDER = "everest/model/config.yml" @@ -149,13 +146,13 @@ def test_everest_main_lint_entry(copy_math_func_test_data_to_tmp): with capture_streams() as (out, err), pytest.raises(SystemExit): start_everest(["everest", "lint", CONFIG_FILE_MINIMAL]) - _type = "(type=float_parsing)" + type_ = "(type=float_parsing)" validation_msg = dedent( f"""Loading config file failed with: Found 1 validation error: controls -> 0 -> initial_guess - * Input should be a valid number, unable to parse string as a number {_type} + * Input should be a valid number, unable to parse string as a number {type_} """ ) assert validation_msg in err.getvalue() diff --git a/tests/everest/test_api_snapshots.py b/tests/everest/test_api_snapshots.py index c2768f763b6..2a51b6534b3 100644 --- a/tests/everest/test_api_snapshots.py +++ b/tests/everest/test_api_snapshots.py @@ -1,7 +1,7 @@ import json from datetime import datetime, timedelta from pathlib import Path -from typing import Any, Dict +from typing import Any import orjson import polars @@ -26,7 +26,7 @@ def _round_floats(obj, dec): return obj -def make_api_snapshot(api) -> Dict[str, Any]: +def make_api_snapshot(api) -> dict[str, Any]: api_json = { "batches": api.batches, "control_names": api.control_names, diff --git a/tests/everest/test_config_validation.py b/tests/everest/test_config_validation.py index 5bf38643e9f..2acb73af8b7 100644 --- a/tests/everest/test_config_validation.py +++ b/tests/everest/test_config_validation.py @@ -3,7 +3,7 @@ import re import warnings from pathlib import Path -from typing import Any, Dict, List, Union +from typing import Any import pytest from pydantic import ValidationError @@ -15,7 +15,7 @@ from tests.everest.utils import skipif_no_everest_models -def has_error(error: Union[ValidationError, List[dict]], match: str): +def has_error(error: ValidationError | list[dict], match: str): messages = ( [error_dict["msg"] for error_dict in error.errors()] if isinstance(error, ValidationError) @@ -282,7 +282,7 @@ def test_that_scaled_range_is_valid_range(): ), ) def test_that_invalid_control_initial_guess_outside_bounds( - variables: List[Dict[str, Any]], count: int + variables: list[dict[str, Any]], count: int ): with pytest.raises(ValueError) as e: EverestConfig.with_defaults( diff --git a/tests/everest/test_controls.py b/tests/everest/test_controls.py index a62cf91cd78..e01e1050adf 100644 --- a/tests/everest/test_controls.py +++ b/tests/everest/test_controls.py @@ -2,7 +2,6 @@ import numbers import os from copy import deepcopy -from typing import List import pytest from pydantic import ValidationError @@ -63,7 +62,7 @@ def test_controls_initialization(): def _perturb_control_zero( config: EverestConfig, gmin, gmax, ginit, fill -) -> List[ControlVariableConfig]: +) -> list[ControlVariableConfig]: """Perturbs the variable range of the first control to create interesting configurations. """ diff --git a/tests/everest/test_egg_simulation.py b/tests/everest/test_egg_simulation.py index 9ebf96746ab..0d8dd80092c 100644 --- a/tests/everest/test_egg_simulation.py +++ b/tests/everest/test_egg_simulation.py @@ -671,7 +671,7 @@ def test_run_egg_model(copy_egg_test_data_to_tmp): config = EverestConfig.load_file(CONFIG_FILE) # test callback - class CBTracker(object): + class CBTracker: def __init__(self): self.called = False @@ -802,7 +802,7 @@ def test_egg_model_wells_json_output_no_none(copy_egg_test_data_to_tmp): def test_egg_snapshot(snapshot, copy_egg_test_data_to_tmp): config = EverestConfig.load_file(CONFIG_FILE) - class CBTracker(object): + class CBTracker: def __init__(self): self.called = False diff --git a/tests/everest/test_everest_config.py b/tests/everest/test_everest_config.py index 8300479602e..ca55cb01dbb 100644 --- a/tests/everest/test_everest_config.py +++ b/tests/everest/test_everest_config.py @@ -1,6 +1,5 @@ import logging from pathlib import Path -from typing import List import pytest @@ -30,7 +29,7 @@ def test_that_control_config_is_initialized_with_control_variables(): } parsed_config = ControlConfig(**controls_dict) - assert isinstance(parsed_config.variables, List) + assert isinstance(parsed_config.variables, list) [v1, v2] = parsed_config.variables diff --git a/tests/everest/test_fm_plugins.py b/tests/everest/test_fm_plugins.py index b9eb1aef620..a4ba8a4da0b 100644 --- a/tests/everest/test_fm_plugins.py +++ b/tests/everest/test_fm_plugins.py @@ -1,6 +1,6 @@ import logging +from collections.abc import Callable, Iterator, Sequence from itertools import chain -from typing import Callable, Iterator, Sequence, Type import pluggy import pytest @@ -47,22 +47,22 @@ def test_everest_models_jobs(): def test_multiple_plugins(plugin_manager): - _SCHEMAS = [{"job1": 1}, {"job2": 2}] + SCHEMAS = [{"job1": 1}, {"job2": 2}] class Plugin1: @hookimpl def get_forward_models_schemas(self): - return [_SCHEMAS[0]] + return [SCHEMAS[0]] class Plugin2: @hookimpl def get_forward_models_schemas(self): - return [_SCHEMAS[1]] + return [SCHEMAS[1]] pm = plugin_manager(Plugin1(), Plugin2()) jobs = list(chain.from_iterable(pm.hook.get_forward_models_schemas())) - for value in _SCHEMAS: + for value in SCHEMAS: assert value in jobs @@ -72,7 +72,7 @@ class Model(BaseModel): class Plugin: @hookimpl - def parse_forward_model_schema(self, path: str, schema: Type[BaseModel]): + def parse_forward_model_schema(self, path: str, schema: type[BaseModel]): return schema.model_validate({"content": path}) pm = plugin_manager(Plugin()) diff --git a/tests/everest/test_logging.py b/tests/everest/test_logging.py index 8c81276b054..e86337a5d0a 100644 --- a/tests/everest/test_logging.py +++ b/tests/everest/test_logging.py @@ -5,10 +5,7 @@ from ert.scheduler.event import FinishedEvent from everest.config import EverestConfig, ServerConfig -from everest.detached import ( - start_server, - wait_for_server, -) +from everest.detached import start_server, wait_for_server from everest.util import makedirs_if_needed CONFIG_FILE = "config_fm_failure.yml" diff --git a/tests/everest/test_repo_configs.py b/tests/everest/test_repo_configs.py index f848c0b808e..57c8ddb45a8 100644 --- a/tests/everest/test_repo_configs.py +++ b/tests/everest/test_repo_configs.py @@ -32,7 +32,7 @@ def test_all_repo_configs(): config_folders = map(lambda fn: os.path.join(repo_dir, fn), config_folders) # noqa E731 is_yaml = lambda fn: fn.endswith(".yml") - is_data = lambda fn: any((df in fn for df in data_folders)) + is_data = lambda fn: any(df in fn for df in data_folders) def is_config(fn): return is_yaml(fn) and not is_data(fn) and "invalid" not in fn diff --git a/tests/everest/test_templating.py b/tests/everest/test_templating.py index d444971d7b9..8b5aa81f004 100644 --- a/tests/everest/test_templating.py +++ b/tests/everest/test_templating.py @@ -234,7 +234,7 @@ def test_user_specified_data_n_template( assert os.path.isfile(expected_file) # Check expected contents of file - with open(expected_file, "r", encoding="utf-8") as f: + with open(expected_file, encoding="utf-8") as f: contents = f.read() assert ( contents == "VALUE1+VALUE2" diff --git a/tests/everest/test_workflows.py b/tests/everest/test_workflows.py index 15e3cc85042..f93b10a8f87 100644 --- a/tests/everest/test_workflows.py +++ b/tests/everest/test_workflows.py @@ -1,5 +1,5 @@ +from collections.abc import Callable from pathlib import Path -from typing import Callable, Optional import pytest @@ -33,7 +33,7 @@ def test_workflow_run(copy_mocked_test_data_to_tmp, evaluator_server_config_gene @pytest.mark.parametrize("config", ("array", "index")) def test_state_modifier_workflow_run( config: str, - copy_testdata_tmpdir: Callable[[Optional[str]], Path], + copy_testdata_tmpdir: Callable[[str | None], Path], evaluator_server_config_generator, ) -> None: cwd = copy_testdata_tmpdir("open_shut_state_modifier") diff --git a/tests/everest/utils/__init__.py b/tests/everest/utils/__init__.py index 84ad48afd89..3dc2491a579 100644 --- a/tests/everest/utils/__init__.py +++ b/tests/everest/utils/__init__.py @@ -83,7 +83,7 @@ def satisfy(predicate): https://stackoverflow.com/questions/21611559/assert-that-a-method-was-called-with-one-argument-out-of-several """ - class _PredicateChecker(object): + class _PredicateChecker: def __eq__(self, obj): return predicate(obj) @@ -100,7 +100,7 @@ def satisfy_callable(): return satisfy(callable) -class MockParser(object): +class MockParser: """ Small class that contains the necessary functions in order to test custom validation functions used with the argparse module diff --git a/tests/everest/utils/test_pydantic_doc_generation.py b/tests/everest/utils/test_pydantic_doc_generation.py index 294603d099d..738d6636246 100644 --- a/tests/everest/utils/test_pydantic_doc_generation.py +++ b/tests/everest/utils/test_pydantic_doc_generation.py @@ -9,7 +9,7 @@ def test_generated_doc(): """ committed_file = relpath("..", "..", "docs", "everest", "config_generated.rst") - with open(committed_file, "r", encoding="utf-8") as fp: + with open(committed_file, encoding="utf-8") as fp: committed_text = fp.read() generated_rst = generate_docs_pydantic_to_rst()