Skip to content

Commit

Permalink
Add more tests of error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed Jan 8, 2024
1 parent 97dcacf commit 68fc26d
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 74 deletions.
100 changes: 60 additions & 40 deletions src/ert/config/_read_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from datetime import datetime, timedelta
from enum import Enum, auto
from fnmatch import fnmatch
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -61,7 +61,7 @@ def from_keyword(cls, summary_keyword: str) -> _SummaryType:
"W": cls.WELL,
}
if summary_keyword == "":
raise ValueError("Got empty summary_keyword")
raise ValueError("Got empty summary keyword")
if any(special in summary_keyword for special in SPECIAL_KEYWORDS):
return cls.OTHER
if summary_keyword[0] in KEYWORD_TYPE_MAPPING:
Expand Down Expand Up @@ -93,9 +93,12 @@ def _cell_index(
return array_index + 1, j + 1, k + 1


T = TypeVar("T")


def _check_if_missing(
keyword_name: str, missing_key: str, *test_vars: Optional[int]
) -> List[int]:
keyword_name: str, missing_key: str, *test_vars: Optional[T]
) -> List[T]:
if any(v is None for v in test_vars):
raise ValueError(
f"Found {keyword_name} keyword in summary "
Expand Down Expand Up @@ -149,13 +152,17 @@ def make_summary_key(
r2 = ((number - r1) // 32768) - 10
return f"{keyword}:{r1}-{r2}"
if sum_type == _SummaryType.LOCAL_WELL:
(name,) = _check_if_missing("local well", "WGNAMES", name)
(lgr_name,) = _check_if_missing("local well", "LGRS", lgr_name)
return f"{keyword}:{lgr_name}:{name}"
if sum_type == _SummaryType.LOCAL_BLOCK:
li, lj, lk = _check_if_missing("local block", "NUMLX", li, lj, lk)
(lgr_name,) = _check_if_missing("local block", "LGRS", lgr_name)
return f"{keyword}:{lgr_name}:{li},{lj},{lk}"
if sum_type == _SummaryType.LOCAL_COMPLETION:
nx, ny = _check_if_missing("local completion", "dimens", nx, ny)
(number,) = _check_if_missing("local completion", "nums", number)
i, j, k = _cell_index(number - 1, nx, ny)
li, lj, lk = _check_if_missing("local completion", "NUMLX", li, lj, lk)
(name,) = _check_if_missing("local completion", "WGNAMES", name)
(lgr_name,) = _check_if_missing("local completion", "LGRS", lgr_name)
return f"{keyword}:{lgr_name}:{name}:{li},{lj},{lk}"
if sum_type == _SummaryType.NETWORK:
# This is consistent with resinsight but
Expand Down Expand Up @@ -200,7 +207,7 @@ def _find_file_matching(
raise ValueError(f"Could not find any {kind} matching case path {case}")
if len(candidates) > 1:
raise ValueError(
f"Ambigous reference to {kind} in {case}, could be any of {candidates}"
f"Ambiguous reference to {kind} in {case}, could be any of {candidates}"
)
return os.path.join(dir, candidates[0])

Expand All @@ -215,10 +222,13 @@ def read_summary(
filepath: str, fetch_keys: Sequence[str]
) -> Tuple[List[str], Sequence[datetime], Any]:
summary, spec = _get_summary_filenames(filepath)
date_index, start_date, date_units, keys, indices = _read_spec(spec, fetch_keys)
fetched, time_map = _read_summary(
summary, start_date, date_units, indices, date_index
)
try:
date_index, start_date, date_units, keys, indices = _read_spec(spec, fetch_keys)
fetched, time_map = _read_summary(
summary, start_date, date_units, indices, date_index
)
except resfo.ResfoParsingError as err:
raise ValueError(f"Failed to read summary file {filepath}: {err}") from err
return (keys, time_map, fetched)


Expand All @@ -228,6 +238,14 @@ def _key2str(key: Union[bytes, str]) -> str:
return ret.strip()


def _check_vals(
kw: str, spec: str, vals: Union[npt.NDArray[Any], resfo.MESS]
) -> npt.NDArray[Any]:
if vals is resfo.MESS or isinstance(vals, resfo.MESS):
raise ValueError(f"{kw.strip()} in {spec} has incorrect type MESS")
return vals


def _read_spec(
spec: str, fetch_keys: Sequence[str]
) -> Tuple[int, datetime, DateUnit, List[str], npt.NDArray[np.int32]]:
Expand All @@ -245,7 +263,7 @@ def _read_spec(
"NUMLX ",
"NUMLY ",
"NUMLZ ",
"LGRNAMES",
"LGRS ",
"UNITS ",
]
}
Expand Down Expand Up @@ -274,50 +292,48 @@ def _read_spec(
break
kw = entry.read_keyword()
if kw in arrays:
vals = entry.read_array()
if vals is resfo.MESS or isinstance(vals, resfo.MESS):
raise ValueError(f"{kw} in {spec} was MESS")
arrays[kw] = vals
arrays[kw] = _check_vals(kw, spec, entry.read_array())
if kw == "DIMENS ":
vals = entry.read_array()
if vals is resfo.MESS or isinstance(vals, resfo.MESS):
raise ValueError(f"DIMENS in {spec} was MESS")
vals = _check_vals(kw, spec, entry.read_array())
size = len(vals)
n = vals[0] if size > 0 else None
nx = vals[1] if size > 1 else None
ny = vals[2] if size > 2 else None
if kw == "STARTDAT":
vals = entry.read_array()
if vals is resfo.MESS or isinstance(vals, resfo.MESS):
raise ValueError(f"Startdate in {spec} was MESS")
vals = _check_vals(kw, spec, entry.read_array())
size = len(vals)
day = vals[0] if size > 0 else 0
month = vals[1] if size > 1 else 0
year = vals[2] if size > 2 else 0
hour = vals[3] if size > 3 else 0
minute = vals[4] if size > 4 else 0
microsecond = vals[5] if size > 5 else 0
date = datetime(
day=day,
month=month,
year=year,
hour=hour,
minute=minute,
second=microsecond // 10**6,
microsecond=microsecond % 10**6,
)
try:
date = datetime(
day=day,
month=month,
year=year,
hour=hour,
minute=minute,
second=microsecond // 10**6,
microsecond=microsecond % 10**6,
)
except Exception as err:
raise ValueError(
f"SMSPEC {spec} contains invalid STARTDAT: {err}"
) from err
keywords = arrays["KEYWORDS"]
wgnames = arrays["WGNAMES "]
nums = arrays["NUMS "]
numlx = arrays["NUMLX "]
numly = arrays["NUMLY "]
numlz = arrays["NUMLZ "]
lgr_names = arrays["LGRNAMES"]
lgr_names = arrays["LGRS "]

if date is None:
raise ValueError(f"keyword startdat missing in {spec}")
raise ValueError(f"Keyword startdat missing in {spec}")
if keywords is None:
raise ValueError(f"keywords missing in {spec}")
raise ValueError(f"Keywords missing in {spec}")
if n is None:
n = len(keywords)

Expand Down Expand Up @@ -367,16 +383,22 @@ def optional_get(arr: Optional[npt.NDArray[Any]], idx: int) -> Any:

units = arrays["UNITS "]
if units is None:
raise ValueError(f"keyword units missing in {spec}")
raise ValueError(f"Keyword units missing in {spec}")
if date_index is None:
raise ValueError(f"KEYWORDS did not contain TIME in {spec}")
if date_index >= len(units):
raise ValueError(f"Unit missing for TIME in {spec}")

unit_key = _key2str(units[date_index])
try:
date_unit = DateUnit[unit_key]
except KeyError:
raise ValueError(f"Unknown date unit in {spec}: {unit_key}") from None

return (
date_index,
date,
DateUnit[_key2str(units[date_index])],
date_unit,
list(keys_array),
indices_array,
)
Expand All @@ -403,9 +425,7 @@ def _read_summary(
def read_params() -> None:
nonlocal last_params, values
if last_params is not None:
vals = last_params.read_array()
if vals is resfo.MESS or isinstance(vals, resfo.MESS):
raise ValueError(f"PARAMS in {summary} was MESS")
vals = _check_vals("PARAMS", summary, last_params.read_array())
values.append(vals[indices])
dates.append(start_date + unit.make_delta(float(vals[date_index])))
last_params = None
Expand Down
1 change: 1 addition & 0 deletions tests/unit_tests/config/config_dict_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ def ert_config_values(draw, use_eclbase=booleans):
smspecs(
sum_keys=st.just(sum_keys),
start_date=st.just(Date.from_datetime(first_date)),
use_days=st.just(True),
)
)
std_cutoff = draw(small_floats)
Expand Down
33 changes: 20 additions & 13 deletions tests/unit_tests/config/summary_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,14 @@ def root_memnonic(draw):
second_character = draw(st.sampled_from("OWGVLPT"))
third_character = draw(st.sampled_from("PIF"))
fourth_character = draw(st.sampled_from("RT"))
# local = draw(st.sampled_from(["", "L"])) if first_character in "BCW" else ""
return first_character + second_character + third_character + fourth_character
local = draw(st.sampled_from(["", "L"])) if first_character in "BCW" else ""
return (
local
+ first_character
+ second_character
+ third_character
+ fourth_character
)


@st.composite
Expand Down Expand Up @@ -228,7 +234,6 @@ class Smspec:
region_numbers: List[int]
units: List[str]
start_date: Date
lgr_names: Optional[List[str]] = None
lgrs: Optional[List[str]] = None
numlx: Optional[List[PositiveInt]] = None
numly: Optional[List[PositiveInt]] = None
Expand Down Expand Up @@ -266,7 +271,6 @@ def to_ecl(self) -> List[Tuple[str, Any]]:
("STARTDAT", np.array(self.start_date.to_ecl(), dtype=np.int32)),
]
+ ([("LGRS ", self.lgrs)] if self.lgrs is not None else [])
+ ([("LGRNAMES", self.lgr_names)] if self.lgr_names is not None else [])
+ ([("NUMLX ", self.numlx)] if self.numlx is not None else [])
+ ([("NUMLY ", self.numly)] if self.numly is not None else [])
+ ([("NUMLZ ", self.numlz)] if self.numlz is not None else [])
Expand All @@ -281,36 +285,40 @@ def to_file(self, filelike, file_format: resfo.Format = resfo.Format.UNFORMATTED


@st.composite
def smspecs(
draw,
sum_keys,
start_date,
):
def smspecs(draw, sum_keys, start_date, use_days=None):
"""
Strategy for smspec that ensures that the TIME parameter, as required by
ert, is in the parameters list.
"""
use_days = st.booleans() if use_days is None else use_days
use_locals = draw(st.booleans())
sum_keys = draw(sum_keys)
if any(sk.startswith("L") for sk in sum_keys):
use_locals = True
n = len(sum_keys) + 1
nx = draw(small_ints)
ny = draw(small_ints)
nz = draw(small_ints)
keywords = ["TIME "] + sum_keys
units = ["DAYS "] + draw(st.lists(unit_names, min_size=n - 1, max_size=n - 1))
if draw(use_days):
units = ["DAYS "] + draw(
st.lists(unit_names, min_size=n - 1, max_size=n - 1)
)
else:
units = ["HOURS "] + draw(
st.lists(unit_names, min_size=n - 1, max_size=n - 1)
)
well_names = [":+:+:+:+"] + draw(st.lists(names, min_size=n - 1, max_size=n - 1))
if use_locals: # use local
lgrs = draw(st.lists(names, min_size=n, max_size=n))
numlx = draw(st.lists(small_ints, min_size=n, max_size=n))
numly = draw(st.lists(small_ints, min_size=n, max_size=n))
numlz = draw(st.lists(small_ints, min_size=n, max_size=n))
lgr_names = list(set(lgrs))
else:
lgrs = None
numlx = None
numly = None
numlz = None
lgr_names = None
region_numbers = [-32676] + draw(
st.lists(
from_dtype(np.dtype(np.int32), min_value=1, max_value=nx * ny * nz),
Expand All @@ -336,7 +344,6 @@ def smspecs(
numlx=st.just(numlx),
numly=st.just(numly),
numlz=st.just(numlz),
lgr_names=st.just(lgr_names),
region_numbers=st.just(region_numbers),
units=st.just(units),
start_date=start_date,
Expand Down
Loading

0 comments on commit 68fc26d

Please sign in to comment.