Skip to content

Commit

Permalink
Test transaction in StatefulStorageTest
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed Sep 26, 2024
1 parent f7778f0 commit e16566a
Showing 1 changed file with 61 additions and 1 deletion.
62 changes: 61 additions & 1 deletion tests/ert/unit_tests/storage/test_local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,12 +494,16 @@ def test_write_transaction(data):


class RaisingWriteNamedTemporaryFile:
entered = False

def __init__(self, *args, **kwargs):
self.wrapped = tempfile.NamedTemporaryFile(*args, **kwargs) # noqa
RaisingWriteNamedTemporaryFile.entered = False

def __enter__(self, *args, **kwargs):
self.actual_handle = self.wrapped.__enter__(*args, **kwargs)
mock_handle = MagicMock()
RaisingWriteNamedTemporaryFile.entered = True

def ctrlc(_):
raise RuntimeError()
Expand All @@ -517,9 +521,11 @@ def test_write_transaction_failure(tmp_path):
with patch(
"ert.storage.local_storage.NamedTemporaryFile",
RaisingWriteNamedTemporaryFile,
), pytest.raises(RuntimeError):
) as f, pytest.raises(RuntimeError):
storage._write_transaction(path, b"deadbeaf")

assert f.entered

assert not path.exists()


Expand Down Expand Up @@ -669,6 +675,34 @@ def save_field(self, model_ensemble: Ensemble, field_data):
).to_dataset(),
)

@rule(
model_ensemble=ensembles,
field_data=grid.flatmap(lambda g: arrays(np.float32, shape=g[1].shape)),
)
def write_error_in_save_field(self, model_ensemble: Ensemble, field_data):
storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)
parameters = model_ensemble.parameter_values.values()
fields = [p for p in parameters if isinstance(p, Field)]
iens = 1
assume(not storage_ensemble.realizations_initialized([iens]))
for f in fields:
with patch(
"ert.storage.local_storage.NamedTemporaryFile",
RaisingWriteNamedTemporaryFile,
) as temp_file, pytest.raises(RuntimeError):
storage_ensemble.save_parameters(
f.name,
iens,
xr.DataArray(
field_data,
name="values",
dims=["x", "y", "z"], # type: ignore
).to_dataset(),
)

assert temp_file.entered
assert not storage_ensemble.realizations_initialized([iens])

@rule(
model_ensemble=ensembles,
)
Expand Down Expand Up @@ -831,6 +865,32 @@ def set_failure(self, model_ensemble: Ensemble, data: st.DataObject, message: st
)
model_ensemble.failure_messages[realization] = message

@rule(model_ensemble=ensembles, data=st.data(), message=st.text())
def write_error_in_set_failure(
self,
model_ensemble: Ensemble,
data: st.DataObject,
message: str,
):
storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)
realization = data.draw(
st.integers(min_value=0, max_value=storage_ensemble.ensemble_size - 1)
)
assume(not storage_ensemble.has_failure(realization))

storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)

with patch(
"ert.storage.local_storage.NamedTemporaryFile",
RaisingWriteNamedTemporaryFile,
) as f, pytest.raises(RuntimeError):
storage_ensemble.set_failure(
realization, RealizationStorageState.PARENT_FAILURE, message
)
assert f.entered

assert not storage_ensemble.has_failure(realization)

@rule(model_ensemble=ensembles, data=st.data())
def get_failure(self, model_ensemble: Ensemble, data: st.DataObject):
storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)
Expand Down

0 comments on commit e16566a

Please sign in to comment.