Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Oct 2, 2024
1 parent ebf52fd commit 82c8427
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 21 deletions.
30 changes: 20 additions & 10 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
ConfigWarning,
ErrorInfo,
ForwardModelStepKeys,
HistorySource,
HookRuntime,
init_forward_model_schema,
init_site_config_schema,
Expand Down Expand Up @@ -116,6 +117,7 @@ class ErtConfig:
observation_config: List[
Tuple[str, Union[HistoryValues, SummaryValues, GenObsValues]]
] = field(default_factory=list)
enkf_obs: Optional[EnkfObs] = None

@field_validator("substitution_list", mode="before")
@classmethod
Expand All @@ -130,8 +132,6 @@ def __post_init__(self) -> None:
if self.user_config_file
else os.getcwd()
)
self.enkf_obs: EnkfObs = self._create_observations(self.observation_config)

self.observations: Dict[str, xr.Dataset] = self.enkf_obs.datasets

@staticmethod
Expand Down Expand Up @@ -307,6 +307,7 @@ def from_dict(cls, config_dict) -> Self:
errors.append(err)

try:
ensemble_config = EnsembleConfig.from_dict(config_dict=config_dict)
if obs_config_content:
summary_obs = {
obs[1].key
Expand All @@ -318,7 +319,14 @@ def from_dict(cls, config_dict) -> Self:
config_dict[ConfigKeys.SUMMARY] = [summary_keys] + [
[key] for key in summary_obs if key not in summary_keys
]
ensemble_config = EnsembleConfig.from_dict(config_dict=config_dict)
observations = cls._create_observations(
obs_config_content,
ensemble_config,
model_config.time_map,
model_config.history_source,
)
else:
observations = None
except ConfigValidationError as err:
errors.append(err)

Expand Down Expand Up @@ -351,6 +359,7 @@ def from_dict(cls, config_dict) -> Self:
model_config=model_config,
user_config_file=config_file_path,
observation_config=obs_config_content,
enkf_obs=observations,
)

@classmethod
Expand Down Expand Up @@ -954,24 +963,25 @@ def _installed_forward_model_steps_from_dict(
def preferred_num_cpu(self) -> int:
return int(self.substitution_list.get(f"<{ConfigKeys.NUM_CPU}>", 1))

@staticmethod
def _create_observations(
self,
obs_config_content: Optional[
Dict[str, Union[HistoryValues, SummaryValues, GenObsValues]]
],
ensemble_config: EnsembleConfig,
time_map: Optional[List[datetime]],
history: HistorySource,
) -> EnkfObs:
if not obs_config_content:
return EnkfObs({}, [])
obs_vectors: Dict[str, ObsVector] = {}
obs_time_list: Sequence[datetime] = []
if self.ensemble_config.refcase is not None:
obs_time_list = self.ensemble_config.refcase.all_dates
elif self.model_config.time_map is not None:
obs_time_list = self.model_config.time_map
if ensemble_config.refcase is not None:
obs_time_list = ensemble_config.refcase.all_dates
elif time_map is not None:
obs_time_list = time_map

history = self.model_config.history_source
time_len = len(obs_time_list)
ensemble_config = self.ensemble_config
config_errors: List[ErrorInfo] = []
for obs_name, values in obs_config_content:
try:
Expand Down
10 changes: 5 additions & 5 deletions src/ert/config/general_observation.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List

import numpy as np
import numpy.typing as npt


@dataclass(eq=False)
class GenObservation:
values: npt.NDArray[np.double]
stds: npt.NDArray[np.double]
indices: npt.NDArray[np.int32]
std_scaling: npt.NDArray[np.double]
values: List[float]
stds: List[float]
indices: List[int]
std_scaling: List[float]

def __post_init__(self) -> None:
for val in self.stds:
Expand Down
4 changes: 3 additions & 1 deletion src/ert/config/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,9 @@ def _create_gen_obs(
f"index list ({indices}) must be of equal length",
obs_file if obs_file is not None else "",
)
return GenObservation(values, stds, indices, std_scaling)
return GenObservation(
values.tolist(), stds.tolist(), indices.tolist(), std_scaling.tolist()
)

@classmethod
def _handle_general_observation(
Expand Down
2 changes: 2 additions & 0 deletions tests/ert/unit_tests/config/observations_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,14 @@ def observations(draw, ensemble_keys, summary_keys, std_cutoff, start_date):
stop=st.integers(min_value=1, max_value=10),
error=st.floats(
min_value=0.01,
max_value=1e20,
allow_nan=False,
allow_infinity=False,
exclude_min=True,
),
error_min=st.floats(
min_value=0.0,
max_value=1e20,
allow_nan=False,
allow_infinity=False,
exclude_min=True,
Expand Down
9 changes: 4 additions & 5 deletions tests/ert/unit_tests/config/test_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from pathlib import Path
from textwrap import dedent

import numpy as np
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
Expand Down Expand Up @@ -182,10 +181,10 @@ def test_summary_obs_invalid_observation_std(std):
def test_gen_obs_invalid_observation_std(std):
with pytest.raises(ValueError, match="must be strictly > 0"):
GenObservation(
np.array(range(len(std))),
np.array(std),
np.array(range(len(std))),
np.array(range(len(std))),
list(range(len(std))),
list(std),
list(range(len(std))),
list(range(len(std))),
)


Expand Down

0 comments on commit 82c8427

Please sign in to comment.