Skip to content

Commit

Permalink
Make ErtConfig fetch its own plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
Yngve S. Kristiansen authored and yngve-sk committed Jul 10, 2024
1 parent 7f44cbb commit 3b2f306
Show file tree
Hide file tree
Showing 16 changed files with 43 additions and 79 deletions.
4 changes: 1 addition & 3 deletions src/ert/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def run_cli(args: Namespace, plugin_manager: Optional[ErtPluginManager] = None)
# the config file to be the base name of the original config
args.config = os.path.basename(args.config)

ert_config = ErtConfig.with_plugins(
plugin_manager.forward_model_steps if plugin_manager else []
).from_file(args.config)
ert_config = ErtConfig.with_plugins().from_file(args.config)

local_storage_set_ert_config(ert_config)

Expand Down
6 changes: 5 additions & 1 deletion src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from typing_extensions import Self

from ert.config.gen_data_config import GenDataConfig
from ert.shared.plugins import ErtPluginManager
from ert.substitution_list import SubstitutionList

from ._get_num_cpu import get_num_cpu_from_data_file
Expand Down Expand Up @@ -126,8 +127,11 @@ def __post_init__(self) -> None:

@staticmethod
def with_plugins(
forward_model_step_classes: List[Type[ForwardModelStepPlugin]],
forward_model_step_classes: Optional[List[Type[ForwardModelStepPlugin]]] = None,
) -> Type["ErtConfig"]:
if forward_model_step_classes is None:
forward_model_step_classes = ErtPluginManager().forward_model_steps

preinstalled_fm_steps: Dict[str, ForwardModelStepPlugin] = {}
for fm_step_subclass in forward_model_step_classes:
fm_step = fm_step_subclass()
Expand Down
4 changes: 1 addition & 3 deletions src/ert/gui/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,7 @@ def _start_initial_gui_window(
# the config file to be the base name of the original config
args.config = os.path.basename(args.config)

ert_config = ErtConfig.with_plugins(
plugin_manager.forward_model_steps if plugin_manager else []
).from_file(args.config)
ert_config = ErtConfig.with_plugins().from_file(args.config)

local_storage_set_ert_config(ert_config)
except ConfigValidationError as error:
Expand Down
6 changes: 2 additions & 4 deletions src/ert/libres_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,8 @@ def run_ertscript( # type: ignore
def from_config_file(
cls, config_file: str, read_only: bool = False
) -> "LibresFacade":
with ErtPluginContext() as ctx:
with ErtPluginContext():
return cls(
ErtConfig.with_plugins(
forward_model_step_classes=ctx.plugin_manager.forward_model_steps
).from_file(config_file),
ErtConfig.with_plugins().from_file(config_file),
read_only,
)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, List

from ert import ErtScript
from ert.config import ConfigValidationError
from ert.config.ert_script import ErtScript
from ert.config.parsing.config_errors import ConfigValidationError


class DisableParametersUpdate(ErtScript):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pandas as pd

from ert import ErtScript
from ert.config.ert_script import ErtScript
from ert.exceptions import StorageError

if TYPE_CHECKING:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, Any, List, Tuple

from ert.config import ErtScript
from ert.config.ert_script import ErtScript
from ert.runpaths import Runpaths
from ert.validation import rangestring_to_mask

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, List

from ert import ErtScript
from ert.config import ConfigValidationError
from ert.config.ert_script import ErtScript
from ert.config.parsing.config_errors import ConfigValidationError


class MisfitPreprocessor(ErtScript):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from typing import TYPE_CHECKING, Dict, List, Optional, Type, no_type_check

from ert.config import ForwardModelStepPlugin
from ert.shared.plugins.plugin_manager import hook_specification

if TYPE_CHECKING:
from ert.config import ForwardModelStepPlugin
from ert.shared.plugins.plugin_response import PluginResponse


Expand Down
9 changes: 5 additions & 4 deletions src/ert/shared/plugins/plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@

import pluggy

from ert.config.forward_model_step import (
ForwardModelStepDocumentation,
ForwardModelStepPlugin,
)
from ert.shared.plugins.workflow_config import WorkflowConfigs

_PLUGIN_NAMESPACE = "ert"
Expand All @@ -46,6 +42,11 @@
import ert.shared.plugins.hook_specifications # noqa

if TYPE_CHECKING:
from ert.config.forward_model_step import (
ForwardModelStepDocumentation,
ForwardModelStepPlugin,
)

from .plugin_response import PluginMetadata, PluginResponse

K = TypeVar("K")
Expand Down
7 changes: 0 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from qtpy.QtWidgets import QApplication

from _ert.threading import set_signal_handler
from ert.shared.plugins import ErtPluginManager

if sys.version_info >= (3, 9):
from importlib.resources import files
Expand Down Expand Up @@ -159,12 +158,6 @@ def snake_oil_case(setup_case):
return setup_case("snake_oil", "snake_oil.ert")


@pytest.fixture()
def ErtConfigWithPlugins():
pm = ErtPluginManager()
return ErtConfig.with_plugins(forward_model_step_classes=pm.forward_model_steps)


@pytest.fixture()
def minimum_case(use_tmpdir):
with open("minimum_config", "w", encoding="utf-8") as fout:
Expand Down
23 changes: 1 addition & 22 deletions tests/integration_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,9 @@
ITERATIVE_ENSEMBLE_SMOOTHER_MODE,
TEST_RUN_MODE,
)
from ert.shared.plugins import ErtPluginManager
from ert.storage import open_storage
from tests.unit_tests.all.plugins import dummy_plugins

from .run_cli import run_cli, run_cli_with_pm


def test_that_cli_runs_forward_model_from_plugin(tmp_path):
test_config_contents = dedent(
"""
NUM_REALIZATIONS 1
FORWARD_MODEL DummyForwardModel
"""
)
with open(tmp_path / "test.ert", "w", encoding="utf-8") as fh:
fh.write(test_config_contents)

pm = ErtPluginManager(plugins=[dummy_plugins])
run_cli_with_pm(
[TEST_RUN_MODE, "--disable-monitor", str(tmp_path / "test.ert")], pm
)
assert os.path.exists(
tmp_path / "simulations" / "realization-0" / "iter-0" / "dummy.out"
)
from .run_cli import run_cli


@pytest.mark.filterwarnings("ignore::ert.config.ConfigWarning")
Expand Down
7 changes: 3 additions & 4 deletions tests/integration_tests/test_storage_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from packaging import version

from ert.config import ErtConfig
from ert.storage import open_storage
from ert.storage.local_storage import local_storage_set_ert_config

Expand Down Expand Up @@ -93,14 +94,13 @@ def test_that_storage_matches(
snapshot,
monkeypatch,
ert_version,
ErtConfigWithPlugins,
):
shutil.copytree(
block_storage_path / f"all_data_types/storage-{ert_version}",
tmp_path / "all_data_types" / f"storage-{ert_version}",
)
monkeypatch.chdir(tmp_path / "all_data_types")
ert_config = ErtConfigWithPlugins.from_file("config.ert")
ert_config = ErtConfig.with_plugins().from_file("config.ert")
local_storage_set_ert_config(ert_config)
# To make sure all tests run against the same snapshot
snapshot.snapshot_dir = snapshot.snapshot_dir.parent
Expand Down Expand Up @@ -224,7 +224,6 @@ def test_that_storage_works_with_missing_parameters_and_responses(
snapshot,
monkeypatch,
ert_version,
ErtConfigWithPlugins,
):
storage_path = tmp_path / "all_data_types" / f"storage-{ert_version}"
shutil.copytree(
Expand All @@ -246,7 +245,7 @@ def test_that_storage_works_with_missing_parameters_and_responses(
os.remove(real_dir / "GEN.nc")

monkeypatch.chdir(tmp_path / "all_data_types")
ert_config = ErtConfigWithPlugins.from_file("config.ert")
ert_config = ErtConfig.with_plugins().from_file("config.ert")
local_storage_set_ert_config(ert_config)
# To make sure all tests run against the same snapshot
snapshot.snapshot_dir = snapshot.snapshot_dir.parent
Expand Down
32 changes: 14 additions & 18 deletions tests/unit_tests/config/test_forward_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_ert_config_throws_on_missing_forward_model_step(


@pytest.mark.usefixtures("use_tmpdir")
def test_that_substitutions_can_be_done_in_job_names(ErtConfigWithPlugins):
def test_that_substitutions_can_be_done_in_job_names():
"""
Regression test for a usage case involving setting ECL100 or ECL300
that was broken by changes to forward_model substitutions.
Expand All @@ -55,14 +55,14 @@ def test_that_substitutions_can_be_done_in_job_names(ErtConfigWithPlugins):
with open(test_config_file_name, "w", encoding="utf-8") as fh:
fh.write(test_config_contents)

ert_config = ErtConfigWithPlugins.from_file(test_config_file_name)
ert_config = ErtConfig.with_plugins().from_file(test_config_file_name)
assert len(ert_config.forward_model_steps) == 1
job = ert_config.forward_model_steps[0]
assert job.name == "ECLIPSE100"


@pytest.mark.usefixtures("use_tmpdir")
def test_parsing_forward_model_with_double_dash_is_possible(ErtConfigWithPlugins):
def test_parsing_forward_model_with_double_dash_is_possible():
"""This is a regression test, making sure that we can put double dashes in strings.
The use case is that a file name is utilized that contains two consecutive hyphens,
which by the ert config parser used to be interpreted as a comment. In the new
Expand All @@ -79,7 +79,7 @@ def test_parsing_forward_model_with_double_dash_is_possible(ErtConfigWithPlugins
with open(test_config_file_name, "w", encoding="utf-8") as fh:
fh.write(test_config_contents)

res_config = ErtConfigWithPlugins.from_file(test_config_file_name)
res_config = ErtConfig.with_plugins().from_file(test_config_file_name)
assert res_config.model_config.jobname_format_string == "job_<IENS>--hei"
assert (
res_config.forward_model_steps[0].private_args["<TO>"]
Expand All @@ -88,9 +88,7 @@ def test_parsing_forward_model_with_double_dash_is_possible(ErtConfigWithPlugins


@pytest.mark.usefixtures("use_tmpdir")
def test_parsing_forward_model_with_quotes_does_not_introduce_spaces(
ErtConfigWithPlugins,
):
def test_parsing_forward_model_with_quotes_does_not_introduce_spaces():
"""this is a regression test, making sure that we do not by mistake introduce
spaces while parsing forward model lines that contain quotation marks
Expand All @@ -110,15 +108,15 @@ def test_parsing_forward_model_with_quotes_does_not_introduce_spaces(
with open(test_config_file_name, "w", encoding="utf-8") as fh:
fh.write(test_config_contents)

ert_config = ErtConfigWithPlugins.from_file(test_config_file_name)
ert_config = ErtConfig.with_plugins().from_file(test_config_file_name)
assert list(ert_config.forward_model_steps[0].private_args.values()) == [
"foo",
"smt/<foo>/bar/xx/t--s.s/yy/z/z/oo",
]


@pytest.mark.usefixtures("use_tmpdir")
def test_that_comments_are_ignored(ErtConfigWithPlugins):
def test_that_comments_are_ignored():
"""This is a regression test, making sure that we can put double dashes in strings.
The use case is that a file name is utilized that contains two consecutive hyphens,
which by the ert config parser used to be interpreted as a comment. In the new
Expand All @@ -136,7 +134,7 @@ def test_that_comments_are_ignored(ErtConfigWithPlugins):
with open(test_config_file_name, "w", encoding="utf-8") as fh:
fh.write(test_config_contents)

res_config = ErtConfigWithPlugins.from_file(test_config_file_name)
res_config = ErtConfig.with_plugins().from_file(test_config_file_name)
assert res_config.model_config.jobname_format_string == "job_<IENS>--hei"
assert (
res_config.forward_model_steps[0].private_args["<TO>"]
Expand All @@ -145,9 +143,7 @@ def test_that_comments_are_ignored(ErtConfigWithPlugins):


@pytest.mark.usefixtures("use_tmpdir")
def test_that_quotations_in_forward_model_arglist_are_handled_correctly(
ErtConfigWithPlugins,
):
def test_that_quotations_in_forward_model_arglist_are_handled_correctly():
"""This is a regression test, making sure that quoted strings behave consistently.
They should all result in the same.
See https://github.com/equinor/ert/issues/2766"""
Expand All @@ -164,7 +160,7 @@ def test_that_quotations_in_forward_model_arglist_are_handled_correctly(
with open(test_config_file_name, "w", encoding="utf-8") as fh:
fh.write(test_config_contents)

res_config = ErtConfigWithPlugins.from_file(test_config_file_name)
res_config = ErtConfig.with_plugins().from_file(test_config_file_name)

assert res_config.forward_model_steps[0].private_args["<FROM>"] == "some, thing"
assert res_config.forward_model_steps[0].private_args["<TO>"] == "some stuff"
Expand Down Expand Up @@ -215,7 +211,7 @@ def test_that_installing_two_forward_model_steps_with_the_same_name_warn():

@pytest.mark.usefixtures("use_tmpdir")
def test_that_forward_model_substitution_does_not_warn_about_reaching_max_iterations(
caplog, ErtConfigWithPlugins
caplog,
):
test_config_file_name = "test.ert"
test_config_contents = dedent(
Expand All @@ -227,7 +223,7 @@ def test_that_forward_model_substitution_does_not_warn_about_reaching_max_iterat
with open(test_config_file_name, "w", encoding="utf-8") as fh:
fh.write(test_config_contents)

ert_config = ErtConfigWithPlugins.from_file(test_config_file_name)
ert_config = ErtConfig.with_plugins().from_file(test_config_file_name)
with caplog.at_level(logging.WARNING):
ert_config.forward_model_data_to_json(0, 0, 0)
assert "Reached max iterations" not in caplog.text
Expand All @@ -254,7 +250,7 @@ def test_that_installing_two_forward_model_steps_with_the_same_name_warn_with_di


@pytest.mark.usefixtures("use_tmpdir")
def test_that_spaces_in_forward_model_args_are_dropped(ErtConfigWithPlugins):
def test_that_spaces_in_forward_model_args_are_dropped():
test_config_file_name = "test.ert"
# Intentionally inserted several spaces before comma
test_config_contents = dedent(
Expand All @@ -266,7 +262,7 @@ def test_that_spaces_in_forward_model_args_are_dropped(ErtConfigWithPlugins):
with open(test_config_file_name, "w", encoding="utf-8") as fh:
fh.write(test_config_contents)

ert_config = ErtConfigWithPlugins.from_file(test_config_file_name)
ert_config = ErtConfig.with_plugins().from_file(test_config_file_name)
assert len(ert_config.forward_model_steps) == 1
job = ert_config.forward_model_steps[0]
assert job.private_args.get("<VERSION>") == "smersion"
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/shared/share/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,8 @@ def minimal_case(tmpdir):
yield


def test_shell_script_jobs_availability(ErtConfigWithPlugins, minimal_case):
ert_config = ErtConfigWithPlugins.from_file("config.ert")
def test_shell_script_jobs_availability(minimal_case):
ert_config = ErtConfig.with_plugins().from_file("config.ert")
fm_shell_jobs = {}
for fm_step in ert_config.installed_forward_model_steps.values():
exe = fm_step.executable
Expand Down
6 changes: 2 additions & 4 deletions tests/unit_tests/test_run_path_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,7 @@ def test_that_data_file_sets_num_cpu(eclipse_data, expected_cpus):
"ignore:.*RUNPATH keyword contains deprecated value placeholders.*:ert.config.ConfigWarning"
)
@pytest.mark.usefixtures("use_tmpdir")
def test_that_deprecated_runpath_substitution_remain_valid(
prior_ensemble, ErtConfigWithPlugins
):
def test_that_deprecated_runpath_substitution_remain_valid(prior_ensemble):
"""
This checks that deprecated runpath substitution, using %d, remain intact.
"""
Expand All @@ -348,7 +346,7 @@ def test_that_deprecated_runpath_substitution_remain_valid(
)
Path("config.ert").write_text(config_text, encoding="utf-8")

ert_config = ErtConfigWithPlugins.from_file("config.ert")
ert_config = ErtConfig.with_plugins().from_file("config.ert")

run_context = ensemble_context(
prior_ensemble,
Expand Down

0 comments on commit 3b2f306

Please sign in to comment.