Fix an issue where aborted processes could corrupt storage #8802

Merged · 4 commits · Sep 27, 2024
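
The fix routes all storage writes through a write-then-rename transaction: bytes are first written to a temporary file in a swap directory on the same filesystem, and only a fully written file is renamed over the destination. Because rename is atomic on POSIX filesystems, a process killed mid-write leaves at worst a stale temp file, never a truncated index.json or .nc file. A minimal sketch of the pattern, with illustrative names rather than the PR's exact code (the PR renames inside the with block; renaming after it, as here, additionally guarantees the handle was flushed and closed first):

import os
from pathlib import Path
from tempfile import NamedTemporaryFile

def atomic_write(filename: os.PathLike, data: bytes, swap_dir: Path) -> None:
    # Stage the bytes in a temp file on the same filesystem as the target.
    swap_dir.mkdir(parents=True, exist_ok=True)
    with NamedTemporaryFile(dir=swap_dir, delete=False) as f:
        f.write(data)
    # Closing the handle flushed the data; rename() then atomically
    # replaces the destination, so readers see old bytes or new bytes,
    # never a truncated file.
    os.rename(f.name, filename)
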
22 changes: 13 additions & 9 deletions src/ert/storage/local_ensemble.py
@@ -143,8 +143,9 @@ def create(
             started_at=datetime.now(),
         )

-        with open(path / "index.json", mode="w", encoding="utf-8") as f:
-            print(index.model_dump_json(), file=f)
+        storage._write_transaction(
+            path / "index.json", index.model_dump_json().encode("utf-8")
+        )

         return cls(storage, path, Mode.WRITE)

@@ -422,8 +423,9 @@ def set_failure(
         error = _Failure(
             type=failure_type, message=message if message else "", time=datetime.now()
         )
-        with open(filename, mode="w", encoding="utf-8") as f:
-            print(error.model_dump_json(), file=f)
+        self._storage._write_transaction(
+            filename, error.model_dump_json().encode("utf-8")
+        )

     def unset_failure(
         self,
@@ -589,8 +591,8 @@ def load_cross_correlations(self) -> xr.Dataset:

     @require_write
     def save_observation_scaling_factors(self, dataset: xr.Dataset) -> None:
-        dataset.to_netcdf(
-            self.mount_point / "observation_scaling_factors.nc", engine="scipy"
+        self._storage._to_netcdf_transaction(
+            self.mount_point / "observation_scaling_factors.nc", dataset
         )

     def load_observation_scaling_factors(
@@ -620,7 +622,7 @@ def save_cross_correlations(
         }
         dataset = xr.Dataset(data_vars)
         file_path = os.path.join(self.mount_point, "corr_XY.nc")
-        dataset.to_netcdf(path=file_path, engine="scipy")
+        self._storage._to_netcdf_transaction(file_path, dataset)

     @lru_cache  # noqa: B019
     def load_responses(self, key: str, realizations: Tuple[int]) -> xr.Dataset:
@@ -820,7 +822,9 @@ def save_parameters(
         path = self._realization_dir(realization) / f"{_escape_filename(group)}.nc"
         path.parent.mkdir(exist_ok=True)

-        dataset.expand_dims(realizations=[realization]).to_netcdf(path, engine="scipy")
+        self._storage._to_netcdf_transaction(
+            path, dataset.expand_dims(realizations=[realization])
+        )

     @require_write
     def save_response(
@@ -855,7 +859,7 @@ def save_response(
         output_path = self._realization_dir(realization)
         Path.mkdir(output_path, parents=True, exist_ok=True)

-        data.to_netcdf(output_path / f"{response_type}.nc", engine="scipy")
+        self._storage._to_netcdf_transaction(output_path / f"{response_type}.nc", data)

     def calculate_std_dev_for_parameter(self, parameter_group: str) -> xr.Dataset:
         if parameter_group not in self.experiment.parameter_configuration:
22 changes: 14 additions & 8 deletions src/ert/storage/local_experiment.py
@@ -129,24 +129,30 @@ def create(
         for parameter in parameters or []:
             parameter.save_experiment_data(path)
             parameter_data.update({parameter.name: parameter.to_dict()})
-        with open(path / cls._parameter_file, "w", encoding="utf-8") as f:
-            json.dump(parameter_data, f, indent=2)
+        storage._write_transaction(
+            path / cls._parameter_file,
+            json.dumps(parameter_data, indent=2).encode("utf-8"),
+        )

         response_data = {}
         for response in responses or []:
             response_data.update({response.response_type: response.to_dict()})
-        with open(path / cls._responses_file, "w", encoding="utf-8") as f:
-            json.dump(response_data, f, default=str, indent=2)
+        storage._write_transaction(
+            path / cls._responses_file,
+            json.dumps(response_data, default=str, indent=2).encode("utf-8"),
+        )

         if observations:
             output_path = path / "observations"
             output_path.mkdir()
             for obs_name, dataset in observations.items():
-                dataset.to_netcdf(output_path / f"{obs_name}", engine="scipy")
+                storage._to_netcdf_transaction(output_path / f"{obs_name}", dataset)

-        with open(path / cls._metadata_file, "w", encoding="utf-8") as f:
-            simulation_data = simulation_arguments if simulation_arguments else {}
-            json.dump(simulation_data, f, cls=ContextBoolEncoder)
+        simulation_data = simulation_arguments if simulation_arguments else {}
+        storage._write_transaction(
+            path / cls._metadata_file,
+            json.dumps(simulation_data, cls=ContextBoolEncoder).encode("utf-8"),
+        )

         return cls(storage, path, Mode.WRITE)

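
A note on the json.dump → json.dumps(...).encode("utf-8") rewrites above: the bytes handed to _write_transaction are identical to what json.dump would have written to a UTF-8 text file, so the refactor changes durability, not the on-disk format. A small illustration with a hypothetical payload:

import json
from pathlib import Path

payload = {"name": "example", "iterations": 3}  # hypothetical experiment metadata

# json.dump into a UTF-8 text file and json.dumps(...).encode("utf-8")
# produce identical bytes.
with open("direct.json", "w", encoding="utf-8") as f:
    json.dump(payload, f, indent=2)

assert Path("direct.json").read_bytes() == json.dumps(payload, indent=2).encode("utf-8")
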
39 changes: 37 additions & 2 deletions src/ert/storage/local_storage.py
@@ -6,7 +6,9 @@
 import os
 import shutil
 from datetime import datetime
+from functools import cached_property
 from pathlib import Path
+from tempfile import NamedTemporaryFile
 from textwrap import dedent
 from types import TracebackType
 from typing import (
@@ -71,6 +73,7 @@ class LocalStorage(BaseMode):
     LOCK_TIMEOUT = 5
     EXPERIMENTS_PATH = "experiments"
     ENSEMBLES_PATH = "ensembles"
+    SWAP_PATH = "swp"

     def __init__(
         self,
@@ -248,6 +251,10 @@ def _ensemble_path(self, ensemble_id: UUID) -> Path:
     def _experiment_path(self, experiment_id: UUID) -> Path:
         return self.path / self.EXPERIMENTS_PATH / str(experiment_id)

+    @cached_property
+    def _swap_path(self) -> Path:
+        return self.path / self.SWAP_PATH
+
     def __enter__(self) -> LocalStorage:
         return self

@@ -446,8 +453,10 @@ def _add_migration_information(

     @require_write
     def _save_index(self) -> None:
-        with open(self.path / "index.json", mode="w", encoding="utf-8") as f:
-            print(self._index.model_dump_json(indent=4), file=f)
+        self._write_transaction(
+            self.path / "index.json",
+            self._index.model_dump_json(indent=4).encode("utf-8"),
+        )

     @require_write
     def _migrate(self, version: int) -> None:
@@ -546,6 +555,32 @@ def get_unique_experiment_name(self, experiment_name: str) -> str:
         else:
             return experiment_name + "_0"

+    def _write_transaction(self, filename: str | os.PathLike[str], data: bytes) -> None:
+        """
+        Writes the data to the filename as a transaction.
+
+        Guarantees to not leave half-written or empty files on disk if the write
+        fails or the process is killed.
+        """
+        self._swap_path.mkdir(parents=True, exist_ok=True)
+        with NamedTemporaryFile(dir=self._swap_path, delete=False) as f:
+            f.write(data)
+            os.rename(f.name, filename)
+
+    def _to_netcdf_transaction(
+        self, filename: str | os.PathLike[str], dataset: xr.Dataset
+    ) -> None:
+        """
+        Writes the dataset to the filename as a transaction.
+
+        Guarantees to not leave half-written or empty files on disk if the write
+        fails or the process is killed.
+        """
+        self._swap_path.mkdir(parents=True, exist_ok=True)
+        with NamedTemporaryFile(dir=self._swap_path, delete=False) as f:
+            dataset.to_netcdf(f, engine="scipy")
+            os.rename(f.name, filename)
+

 def _storage_version(path: Path) -> int:
     if not path.exists():
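
A design note on SWAP_PATH: os.rename is only atomic, and only guaranteed to succeed, when source and destination are on the same filesystem. Placing the swap directory under self.path rather than in the system temp directory presumably ensures exactly that. A sketch of the failure mode the placement avoids (paths here are illustrative):

import errno
import os
from tempfile import NamedTemporaryFile

with NamedTemporaryFile(dir="/tmp", delete=False) as f:
    f.write(b"staged data")

try:
    # Across mount points rename is refused outright (EXDEV); shutil.move
    # would fall back to copy + delete, reintroducing the non-atomic
    # window that the transaction pattern exists to remove.
    os.rename(f.name, "/mnt/elsewhere/dest")
except OSError as err:
    if err.errno == errno.EXDEV:
        print("cross-device rename refused; keep the swap dir on the target filesystem")
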
110 changes: 109 additions & 1 deletion tests/ert/unit_tests/storage/test_local_storage.py
@@ -13,7 +13,7 @@
 import numpy as np
 import pytest
 import xarray as xr
-from hypothesis import assume
+from hypothesis import assume, given
 from hypothesis.extra.numpy import arrays
 from hypothesis.stateful import Bundle, RuleBasedStateMachine, initialize, rule

@@ -483,6 +483,60 @@ def fields(draw, egrid, num_fields=small_ints) -> List[Field]:
     ]


+@pytest.mark.usefixtures("use_tmpdir")
+@given(st.binary())
+def test_write_transaction(data):
+    with open_storage(".", "w") as storage:
+        filepath = Path("./file.txt")
+        storage._write_transaction(filepath, data)
+
+        assert filepath.read_bytes() == data
Review thread on the assert above:

Contributor: deindent one level to make sure the context manager does not delete it (which I assume it is not meant to do)?

Contributor: ah, that context manager is probably irrelevant to _write_transaction.



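For context on the thread above: with delete=False, NamedTemporaryFile's __exit__ only closes the handle and never unlinks the path, so the rename is safe whether it happens inside or outside the with block; with the default delete=True, exiting the block would attempt to unlink the already-renamed temp path. A small demonstration (target.bin is an arbitrary example path):

import os
from tempfile import NamedTemporaryFile

with NamedTemporaryFile(delete=False) as f:
    f.write(b"payload")
    tmp_name = f.name

# With delete=False, __exit__ only closes the handle (flushing buffers);
# the file is still on disk and can be renamed into place.
assert os.path.exists(tmp_name)
os.rename(tmp_name, "target.bin")
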
+
+
+class RaisingWriteNamedTemporaryFile:
+    entered = False
+
+    def __init__(self, *args, **kwargs):
+        self.wrapped = tempfile.NamedTemporaryFile(*args, **kwargs)  # noqa
+        RaisingWriteNamedTemporaryFile.entered = False
+
+    def __enter__(self, *args, **kwargs):
+        self.actual_handle = self.wrapped.__enter__(*args, **kwargs)
+        mock_handle = MagicMock()
+        RaisingWriteNamedTemporaryFile.entered = True
+
+        def ctrlc(_):
+            raise RuntimeError()
+
+        mock_handle.write = ctrlc
+        return mock_handle
+
+    def __exit__(self, *args, **kwargs):
+        self.wrapped.__exit__(*args, **kwargs)
+
+
+def test_write_transaction_failure(tmp_path):
+    with open_storage(tmp_path, "w") as storage:
+        path = tmp_path / "file.txt"
+        with patch(
+            "ert.storage.local_storage.NamedTemporaryFile",
+            RaisingWriteNamedTemporaryFile,
+        ) as f, pytest.raises(RuntimeError):
+            storage._write_transaction(path, b"deadbeaf")
+
+        assert f.entered
+
+    assert not path.exists()
+
+
+def test_write_transaction_overwrites(tmp_path):
+    with open_storage(tmp_path, "w") as storage:
+        path = tmp_path / "file.txt"
+        path.write_text("abc")
+        storage._write_transaction(path, b"deadbeaf")
+        assert path.read_bytes() == b"deadbeaf"
+
+
 @dataclass
 class Ensemble:
     uuid: UUID
@@ -621,6 +675,34 @@ def save_field(self, model_ensemble: Ensemble, field_data):
                 ).to_dataset(),
             )

+    @rule(
+        model_ensemble=ensembles,
+        field_data=grid.flatmap(lambda g: arrays(np.float32, shape=g[1].shape)),
+    )
+    def write_error_in_save_field(self, model_ensemble: Ensemble, field_data):
+        storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)
+        parameters = model_ensemble.parameter_values.values()
+        fields = [p for p in parameters if isinstance(p, Field)]
+        iens = 1
+        assume(not storage_ensemble.realizations_initialized([iens]))
+        for f in fields:
+            with patch(
+                "ert.storage.local_storage.NamedTemporaryFile",
+                RaisingWriteNamedTemporaryFile,
+            ) as temp_file, pytest.raises(RuntimeError):
+                storage_ensemble.save_parameters(
+                    f.name,
+                    iens,
+                    xr.DataArray(
+                        field_data,
+                        name="values",
+                        dims=["x", "y", "z"],  # type: ignore
+                    ).to_dataset(),
+                )
+
+            assert temp_file.entered
+        assert not storage_ensemble.realizations_initialized([iens])
+
     @rule(
         model_ensemble=ensembles,
     )
@@ -783,6 +865,32 @@ def set_failure(self, model_ensemble: Ensemble, data: st.DataObject, message: str):
         )
         model_ensemble.failure_messages[realization] = message

+    @rule(model_ensemble=ensembles, data=st.data(), message=st.text())
+    def write_error_in_set_failure(
+        self,
+        model_ensemble: Ensemble,
+        data: st.DataObject,
+        message: str,
+    ):
+        storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)
+        realization = data.draw(
+            st.integers(min_value=0, max_value=storage_ensemble.ensemble_size - 1)
+        )
+        assume(not storage_ensemble.has_failure(realization))
+
+        storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)
+
+        with patch(
+            "ert.storage.local_storage.NamedTemporaryFile",
+            RaisingWriteNamedTemporaryFile,
+        ) as f, pytest.raises(RuntimeError):
+            storage_ensemble.set_failure(
+                realization, RealizationStorageState.PARENT_FAILURE, message
+            )
+        assert f.entered
+
+        assert not storage_ensemble.has_failure(realization)
+
     @rule(model_ensemble=ensembles, data=st.data())
     def get_failure(self, model_ensemble: Ensemble, data: st.DataObject):
         storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)
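
The two write_error_in_* rules above plug into the test suite's existing hypothesis RuleBasedStateMachine: hypothesis drives random sequences of rules against a real storage instance, and each failure-injection rule asserts that observable state is unchanged afterwards. For readers unfamiliar with the pattern, a minimal self-contained example with a toy counter (unrelated to ert):

from hypothesis import strategies as st
from hypothesis.stateful import RuleBasedStateMachine, rule

class CounterMachine(RuleBasedStateMachine):
    """Hypothesis invokes random sequences of @rule methods; we keep a
    simple model alongside the system under test and compare them."""

    def __init__(self):
        super().__init__()
        self.model = 0   # what we believe the value is
        self.actual = 0  # what the system under test holds

    @rule(n=st.integers(min_value=-100, max_value=100))
    def add(self, n):
        self.model += n
        self.actual += n
        assert self.actual == self.model

# Expose as a unittest-style case so pytest collects and runs it.
TestCounter = CounterMachine.TestCase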