Skip to content

Commit

Permalink
Fix an issue where aborted processes could corrupt storage
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren authored Sep 27, 2024
1 parent 7843b2f commit 321767c
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 20 deletions.
22 changes: 13 additions & 9 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,9 @@ def create(
started_at=datetime.now(),
)

with open(path / "index.json", mode="w", encoding="utf-8") as f:
print(index.model_dump_json(), file=f)
storage._write_transaction(
path / "index.json", index.model_dump_json().encode("utf-8")
)

return cls(storage, path, Mode.WRITE)

Expand Down Expand Up @@ -422,8 +423,9 @@ def set_failure(
error = _Failure(
type=failure_type, message=message if message else "", time=datetime.now()
)
with open(filename, mode="w", encoding="utf-8") as f:
print(error.model_dump_json(), file=f)
self._storage._write_transaction(
filename, error.model_dump_json().encode("utf-8")
)

def unset_failure(
self,
Expand Down Expand Up @@ -589,8 +591,8 @@ def load_cross_correlations(self) -> xr.Dataset:

@require_write
def save_observation_scaling_factors(self, dataset: xr.Dataset) -> None:
dataset.to_netcdf(
self.mount_point / "observation_scaling_factors.nc", engine="scipy"
self._storage._to_netcdf_transaction(
self.mount_point / "observation_scaling_factors.nc", dataset
)

def load_observation_scaling_factors(
Expand Down Expand Up @@ -620,7 +622,7 @@ def save_cross_correlations(
}
dataset = xr.Dataset(data_vars)
file_path = os.path.join(self.mount_point, "corr_XY.nc")
dataset.to_netcdf(path=file_path, engine="scipy")
self._storage._to_netcdf_transaction(file_path, dataset)

@lru_cache # noqa: B019
def load_responses(self, key: str, realizations: Tuple[int]) -> xr.Dataset:
Expand Down Expand Up @@ -820,7 +822,9 @@ def save_parameters(
path = self._realization_dir(realization) / f"{_escape_filename(group)}.nc"
path.parent.mkdir(exist_ok=True)

dataset.expand_dims(realizations=[realization]).to_netcdf(path, engine="scipy")
self._storage._to_netcdf_transaction(
path, dataset.expand_dims(realizations=[realization])
)

@require_write
def save_response(
Expand Down Expand Up @@ -855,7 +859,7 @@ def save_response(
output_path = self._realization_dir(realization)
Path.mkdir(output_path, parents=True, exist_ok=True)

data.to_netcdf(output_path / f"{response_type}.nc", engine="scipy")
self._storage._to_netcdf_transaction(output_path / f"{response_type}.nc", data)

def calculate_std_dev_for_parameter(self, parameter_group: str) -> xr.Dataset:
if parameter_group not in self.experiment.parameter_configuration:
Expand Down
22 changes: 14 additions & 8 deletions src/ert/storage/local_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,24 +129,30 @@ def create(
for parameter in parameters or []:
parameter.save_experiment_data(path)
parameter_data.update({parameter.name: parameter.to_dict()})
with open(path / cls._parameter_file, "w", encoding="utf-8") as f:
json.dump(parameter_data, f, indent=2)
storage._write_transaction(
path / cls._parameter_file,
json.dumps(parameter_data, indent=2).encode("utf-8"),
)

response_data = {}
for response in responses or []:
response_data.update({response.response_type: response.to_dict()})
with open(path / cls._responses_file, "w", encoding="utf-8") as f:
json.dump(response_data, f, default=str, indent=2)
storage._write_transaction(
path / cls._responses_file,
json.dumps(response_data, default=str, indent=2).encode("utf-8"),
)

if observations:
output_path = path / "observations"
output_path.mkdir()
for obs_name, dataset in observations.items():
dataset.to_netcdf(output_path / f"{obs_name}", engine="scipy")
storage._to_netcdf_transaction(output_path / f"{obs_name}", dataset)

with open(path / cls._metadata_file, "w", encoding="utf-8") as f:
simulation_data = simulation_arguments if simulation_arguments else {}
json.dump(simulation_data, f, cls=ContextBoolEncoder)
simulation_data = simulation_arguments if simulation_arguments else {}
storage._write_transaction(
path / cls._metadata_file,
json.dumps(simulation_data, cls=ContextBoolEncoder).encode("utf-8"),
)

return cls(storage, path, Mode.WRITE)

Expand Down
39 changes: 37 additions & 2 deletions src/ert/storage/local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import os
import shutil
from datetime import datetime
from functools import cached_property
from pathlib import Path
from tempfile import NamedTemporaryFile
from textwrap import dedent
from types import TracebackType
from typing import (
Expand Down Expand Up @@ -71,6 +73,7 @@ class LocalStorage(BaseMode):
LOCK_TIMEOUT = 5
EXPERIMENTS_PATH = "experiments"
ENSEMBLES_PATH = "ensembles"
SWAP_PATH = "swp"

def __init__(
self,
Expand Down Expand Up @@ -248,6 +251,10 @@ def _ensemble_path(self, ensemble_id: UUID) -> Path:
def _experiment_path(self, experiment_id: UUID) -> Path:
return self.path / self.EXPERIMENTS_PATH / str(experiment_id)

@cached_property
def _swap_path(self) -> Path:
    # Directory that holds the temporary files used by the transactional
    # write helpers (_write_transaction / _to_netcdf_transaction).  It lives
    # inside the storage root so that renaming a temp file into its final
    # location stays on the same filesystem (required for an atomic rename).
    return self.path / self.SWAP_PATH

def __enter__(self) -> LocalStorage:
return self

Expand Down Expand Up @@ -446,8 +453,10 @@ def _add_migration_information(

@require_write
def _save_index(self) -> None:
with open(self.path / "index.json", mode="w", encoding="utf-8") as f:
print(self._index.model_dump_json(indent=4), file=f)
self._write_transaction(
self.path / "index.json",
self._index.model_dump_json(indent=4).encode("utf-8"),
)

@require_write
def _migrate(self, version: int) -> None:
Expand Down Expand Up @@ -546,6 +555,32 @@ def get_unique_experiment_name(self, experiment_name: str) -> str:
else:
return experiment_name + "_0"

def _write_transaction(self, filename: str | os.PathLike[str], data: bytes) -> None:
"""
Writes the data to the filename as a transaction.
Guarantees to not leave half-written or empty files on disk if the write
fails or the process is killed.
"""
self._swap_path.mkdir(parents=True, exist_ok=True)
with NamedTemporaryFile(dir=self._swap_path, delete=False) as f:
f.write(data)
os.rename(f.name, filename)

def _to_netcdf_transaction(
    self, filename: str | os.PathLike[str], dataset: xr.Dataset
) -> None:
    """
    Write ``dataset`` to ``filename`` as netCDF in a single transaction.

    The dataset is serialized to a temporary file inside the storage's
    swap directory and then moved into place in one step, so readers can
    never observe a half-written or empty file, even if this process is
    killed mid-write.

    :param filename: Final destination path; an existing file is replaced.
    :param dataset: The dataset to serialize with the scipy netCDF engine.
    """
    self._swap_path.mkdir(parents=True, exist_ok=True)
    # delete=False: the temp file must survive the context manager so it
    # can be renamed into its final location afterwards.
    with NamedTemporaryFile(dir=self._swap_path, delete=False) as f:
        try:
            dataset.to_netcdf(f, engine="scipy")
            # Force the serialized bytes to disk before the rename commits
            # the transaction.
            f.flush()
            os.fsync(f.fileno())
        except BaseException:
            # Best-effort cleanup so failed writes do not accumulate
            # orphaned temp files in the swap directory.
            try:
                os.unlink(f.name)
            except Exception:
                pass
            raise
    # os.replace (unlike os.rename) atomically overwrites an existing
    # destination on all platforms, including Windows.
    os.replace(f.name, filename)


def _storage_version(path: Path) -> int:
if not path.exists():
Expand Down
110 changes: 109 additions & 1 deletion tests/ert/unit_tests/storage/test_local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
import pytest
import xarray as xr
from hypothesis import assume
from hypothesis import assume, given
from hypothesis.extra.numpy import arrays
from hypothesis.stateful import Bundle, RuleBasedStateMachine, initialize, rule

Expand Down Expand Up @@ -483,6 +483,60 @@ def fields(draw, egrid, num_fields=small_ints) -> List[Field]:
]


@pytest.mark.usefixtures("use_tmpdir")
@given(st.binary())
def test_write_transaction(data):
    """Any byte payload written via _write_transaction reads back verbatim."""
    with open_storage(".", "w") as storage:
        target = Path("./file.txt")
        storage._write_transaction(target, data)
        assert target.read_bytes() == data


class RaisingWriteNamedTemporaryFile:
    """Stand-in for ``NamedTemporaryFile`` whose handle fails on every write.

    Patched in place of the real class to simulate a process dying (or any
    I/O error) in the middle of a transactional write.  The class attribute
    ``entered`` records whether the context manager was actually entered,
    so tests can assert the code under test reached the write.
    """

    entered = False

    def __init__(self, *args, **kwargs):
        # Keep a genuine temporary file underneath so that creation and
        # cleanup semantics match the real NamedTemporaryFile.
        self.wrapped = tempfile.NamedTemporaryFile(*args, **kwargs)  # noqa
        RaisingWriteNamedTemporaryFile.entered = False

    @staticmethod
    def _fail_write(_):
        # Simulates e.g. Ctrl-C arriving mid-write.
        raise RuntimeError()

    def __enter__(self, *args, **kwargs):
        self.actual_handle = self.wrapped.__enter__(*args, **kwargs)
        RaisingWriteNamedTemporaryFile.entered = True
        handle = MagicMock()
        handle.write = self._fail_write
        return handle

    def __exit__(self, *args, **kwargs):
        self.wrapped.__exit__(*args, **kwargs)


def test_write_transaction_failure(tmp_path):
    """A write that blows up mid-transaction must not create the target file."""
    destination = tmp_path / "file.txt"
    with open_storage(tmp_path, "w") as storage:
        with patch(
            "ert.storage.local_storage.NamedTemporaryFile",
            RaisingWriteNamedTemporaryFile,
        ) as fake_tempfile, pytest.raises(RuntimeError):
            storage._write_transaction(destination, b"deadbeaf")

        # Sanity check: the failure happened inside the temp-file write.
        assert fake_tempfile.entered

    assert not destination.exists()


def test_write_transaction_overwrites(tmp_path):
    """_write_transaction replaces pre-existing content at the target path."""
    with open_storage(tmp_path, "w") as storage:
        target = tmp_path / "file.txt"
        target.write_text("abc")
        storage._write_transaction(target, b"deadbeaf")
        assert target.read_bytes() == b"deadbeaf"


@dataclass
class Ensemble:
uuid: UUID
Expand Down Expand Up @@ -621,6 +675,34 @@ def save_field(self, model_ensemble: Ensemble, field_data):
).to_dataset(),
)

@rule(
    model_ensemble=ensembles,
    field_data=grid.flatmap(lambda g: arrays(np.float32, shape=g[1].shape)),
)
def write_error_in_save_field(self, model_ensemble: Ensemble, field_data):
    # Simulate a write failure while saving field parameters and check
    # that the aborted transaction leaves no trace: the realization must
    # still be reported as uninitialized afterwards.
    storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)
    parameters = model_ensemble.parameter_values.values()
    fields = [p for p in parameters if isinstance(p, Field)]
    iens = 1
    # Only meaningful if nothing has been written for this realization yet.
    assume(not storage_ensemble.realizations_initialized([iens]))
    for f in fields:
        # Patch the temp-file class so the write raises mid-transaction,
        # mimicking the process being killed during save_parameters.
        with patch(
            "ert.storage.local_storage.NamedTemporaryFile",
            RaisingWriteNamedTemporaryFile,
        ) as temp_file, pytest.raises(RuntimeError):
            storage_ensemble.save_parameters(
                f.name,
                iens,
                xr.DataArray(
                    field_data,
                    name="values",
                    dims=["x", "y", "z"],  # type: ignore
                ).to_dataset(),
            )

        # The failure must have happened inside the temp-file write ...
        assert temp_file.entered
        # ... and the aborted write must not have initialized the realization.
        assert not storage_ensemble.realizations_initialized([iens])

@rule(
model_ensemble=ensembles,
)
Expand Down Expand Up @@ -783,6 +865,32 @@ def set_failure(self, model_ensemble: Ensemble, data: st.DataObject, message: st
)
model_ensemble.failure_messages[realization] = message

@rule(model_ensemble=ensembles, data=st.data(), message=st.text())
def write_error_in_set_failure(
    self,
    model_ensemble: Ensemble,
    data: st.DataObject,
    message: str,
):
    # Simulate a write failure while recording a realization failure and
    # check that the aborted transaction does not register the failure.
    storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)
    realization = data.draw(
        st.integers(min_value=0, max_value=storage_ensemble.ensemble_size - 1)
    )
    # Only meaningful if no failure has been recorded for this realization.
    assume(not storage_ensemble.has_failure(realization))

    storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)

    # Patch the temp-file class so the write raises mid-transaction,
    # mimicking the process being killed during set_failure.
    with patch(
        "ert.storage.local_storage.NamedTemporaryFile",
        RaisingWriteNamedTemporaryFile,
    ) as f, pytest.raises(RuntimeError):
        storage_ensemble.set_failure(
            realization, RealizationStorageState.PARENT_FAILURE, message
        )
    # The failure must have happened inside the temp-file write.
    assert f.entered

    # The aborted write must not have recorded the failure.
    assert not storage_ensemble.has_failure(realization)

@rule(model_ensemble=ensembles, data=st.data())
def get_failure(self, model_ensemble: Ensemble, data: st.DataObject):
storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)
Expand Down

0 comments on commit 321767c

Please sign in to comment.