diff --git a/src/ert/storage/local_ensemble.py b/src/ert/storage/local_ensemble.py index 90371976022..7d354fdd82b 100644 --- a/src/ert/storage/local_ensemble.py +++ b/src/ert/storage/local_ensemble.py @@ -143,8 +143,9 @@ def create( started_at=datetime.now(), ) - with open(path / "index.json", mode="w", encoding="utf-8") as f: - print(index.model_dump_json(), file=f) + storage._write_transaction( + path / "index.json", index.model_dump_json().encode("utf-8") + ) return cls(storage, path, Mode.WRITE) @@ -422,8 +423,9 @@ def set_failure( error = _Failure( type=failure_type, message=message if message else "", time=datetime.now() ) - with open(filename, mode="w", encoding="utf-8") as f: - print(error.model_dump_json(), file=f) + self._storage._write_transaction( + filename, error.model_dump_json().encode("utf-8") + ) def unset_failure( self, @@ -589,8 +591,8 @@ def load_cross_correlations(self) -> xr.Dataset: @require_write def save_observation_scaling_factors(self, dataset: xr.Dataset) -> None: - dataset.to_netcdf( - self.mount_point / "observation_scaling_factors.nc", engine="scipy" + self._storage._to_netcdf_transaction( + self.mount_point / "observation_scaling_factors.nc", dataset ) def load_observation_scaling_factors( @@ -620,7 +622,7 @@ def save_cross_correlations( } dataset = xr.Dataset(data_vars) file_path = os.path.join(self.mount_point, "corr_XY.nc") - dataset.to_netcdf(path=file_path, engine="scipy") + self._storage._to_netcdf_transaction(file_path, dataset) @lru_cache # noqa: B019 def load_responses(self, key: str, realizations: Tuple[int]) -> xr.Dataset: @@ -820,7 +822,9 @@ def save_parameters( path = self._realization_dir(realization) / f"{_escape_filename(group)}.nc" path.parent.mkdir(exist_ok=True) - dataset.expand_dims(realizations=[realization]).to_netcdf(path, engine="scipy") + self._storage._to_netcdf_transaction( + path, dataset.expand_dims(realizations=[realization]) + ) @require_write def save_response( @@ -855,7 +859,7 @@ def save_response( output_path = self._realization_dir(realization) Path.mkdir(output_path, parents=True, exist_ok=True) - data.to_netcdf(output_path / f"{response_type}.nc", engine="scipy") + self._storage._to_netcdf_transaction(output_path / f"{response_type}.nc", data) def calculate_std_dev_for_parameter(self, parameter_group: str) -> xr.Dataset: if parameter_group not in self.experiment.parameter_configuration: diff --git a/src/ert/storage/local_experiment.py b/src/ert/storage/local_experiment.py index 7452dc1a0ef..238aeeefbee 100644 --- a/src/ert/storage/local_experiment.py +++ b/src/ert/storage/local_experiment.py @@ -129,24 +129,30 @@ def create( for parameter in parameters or []: parameter.save_experiment_data(path) parameter_data.update({parameter.name: parameter.to_dict()}) - with open(path / cls._parameter_file, "w", encoding="utf-8") as f: - json.dump(parameter_data, f, indent=2) + storage._write_transaction( + path / cls._parameter_file, + json.dumps(parameter_data, indent=2).encode("utf-8"), + ) response_data = {} for response in responses or []: response_data.update({response.response_type: response.to_dict()}) - with open(path / cls._responses_file, "w", encoding="utf-8") as f: - json.dump(response_data, f, default=str, indent=2) + storage._write_transaction( + path / cls._responses_file, + json.dumps(response_data, default=str, indent=2).encode("utf-8"), + ) if observations: output_path = path / "observations" output_path.mkdir() for obs_name, dataset in observations.items(): - dataset.to_netcdf(output_path / f"{obs_name}", engine="scipy") + storage._to_netcdf_transaction(output_path / f"{obs_name}", dataset) - with open(path / cls._metadata_file, "w", encoding="utf-8") as f: - simulation_data = simulation_arguments if simulation_arguments else {} - json.dump(simulation_data, f, cls=ContextBoolEncoder) + simulation_data = simulation_arguments if simulation_arguments else {} + storage._write_transaction( + path / cls._metadata_file, + json.dumps(simulation_data, cls=ContextBoolEncoder).encode("utf-8"), + ) return cls(storage, path, Mode.WRITE) diff --git a/src/ert/storage/local_storage.py b/src/ert/storage/local_storage.py index 43d06e8e827..2ad600e7f80 100644 --- a/src/ert/storage/local_storage.py +++ b/src/ert/storage/local_storage.py @@ -6,7 +6,9 @@ import os import shutil from datetime import datetime +from functools import cached_property from pathlib import Path +from tempfile import NamedTemporaryFile from textwrap import dedent from types import TracebackType from typing import ( @@ -71,6 +73,7 @@ class LocalStorage(BaseMode): LOCK_TIMEOUT = 5 EXPERIMENTS_PATH = "experiments" ENSEMBLES_PATH = "ensembles" + SWAP_PATH = "swp" def __init__( self, @@ -248,6 +251,10 @@ def _ensemble_path(self, ensemble_id: UUID) -> Path: def _experiment_path(self, experiment_id: UUID) -> Path: return self.path / self.EXPERIMENTS_PATH / str(experiment_id) + @cached_property + def _swap_path(self) -> Path: + return self.path / self.SWAP_PATH + def __enter__(self) -> LocalStorage: return self @@ -446,8 +453,10 @@ def _add_migration_information( @require_write def _save_index(self) -> None: - with open(self.path / "index.json", mode="w", encoding="utf-8") as f: - print(self._index.model_dump_json(indent=4), file=f) + self._write_transaction( + self.path / "index.json", + self._index.model_dump_json(indent=4).encode("utf-8"), + ) @require_write def _migrate(self, version: int) -> None: @@ -546,6 +555,32 @@ def get_unique_experiment_name(self, experiment_name: str) -> str: else: return experiment_name + "_0" + def _write_transaction(self, filename: str | os.PathLike[str], data: bytes) -> None: + """ + Writes the data to the filename as a transaction. + + Guarantees to not leave half-written or empty files on disk if the write + fails or the process is killed. + """ + self._swap_path.mkdir(parents=True, exist_ok=True) + with NamedTemporaryFile(dir=self._swap_path, delete=False) as f: + f.write(data) + os.rename(f.name, filename) + + def _to_netcdf_transaction( + self, filename: str | os.PathLike[str], dataset: xr.Dataset + ) -> None: + """ + Writes the dataset to the filename as a transaction. + + Guarantees to not leave half-written or empty files on disk if the write + fails or the process is killed. + """ + self._swap_path.mkdir(parents=True, exist_ok=True) + with NamedTemporaryFile(dir=self._swap_path, delete=False) as f: + dataset.to_netcdf(f, engine="scipy") + os.rename(f.name, filename) + def _storage_version(path: Path) -> int: if not path.exists(): diff --git a/tests/ert/unit_tests/storage/test_local_storage.py b/tests/ert/unit_tests/storage/test_local_storage.py index c50bad3f9b3..a5bc5dd2957 100644 --- a/tests/ert/unit_tests/storage/test_local_storage.py +++ b/tests/ert/unit_tests/storage/test_local_storage.py @@ -13,7 +13,7 @@ import numpy as np import pytest import xarray as xr -from hypothesis import assume +from hypothesis import assume, given from hypothesis.extra.numpy import arrays from hypothesis.stateful import Bundle, RuleBasedStateMachine, initialize, rule @@ -483,6 +483,60 @@ def fields(draw, egrid, num_fields=small_ints) -> List[Field]: ] +@pytest.mark.usefixtures("use_tmpdir") +@given(st.binary()) +def test_write_transaction(data): + with open_storage(".", "w") as storage: + filepath = Path("./file.txt") + storage._write_transaction(filepath, data) + + assert filepath.read_bytes() == data + + +class RaisingWriteNamedTemporaryFile: + entered = False + + def __init__(self, *args, **kwargs): + self.wrapped = tempfile.NamedTemporaryFile(*args, **kwargs) # noqa + RaisingWriteNamedTemporaryFile.entered = False + + def __enter__(self, *args, **kwargs): + self.actual_handle = self.wrapped.__enter__(*args, **kwargs) + mock_handle = MagicMock() + RaisingWriteNamedTemporaryFile.entered = True + + def ctrlc(_): + raise RuntimeError() + + mock_handle.write = ctrlc + return mock_handle + + def __exit__(self, *args, **kwargs): + self.wrapped.__exit__(*args, **kwargs) + + +def test_write_transaction_failure(tmp_path): + with open_storage(tmp_path, "w") as storage: + path = tmp_path / "file.txt" + with patch( + "ert.storage.local_storage.NamedTemporaryFile", + RaisingWriteNamedTemporaryFile, + ) as f, pytest.raises(RuntimeError): + storage._write_transaction(path, b"deadbeaf") + + assert f.entered + + assert not path.exists() + + +def test_write_transaction_overwrites(tmp_path): + with open_storage(tmp_path, "w") as storage: + path = tmp_path / "file.txt" + path.write_text("abc") + storage._write_transaction(path, b"deadbeaf") + assert path.read_bytes() == b"deadbeaf" + + @dataclass class Ensemble: uuid: UUID @@ -621,6 +675,34 @@ def save_field(self, model_ensemble: Ensemble, field_data): ).to_dataset(), ) + @rule( + model_ensemble=ensembles, + field_data=grid.flatmap(lambda g: arrays(np.float32, shape=g[1].shape)), + ) + def write_error_in_save_field(self, model_ensemble: Ensemble, field_data): + storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid) + parameters = model_ensemble.parameter_values.values() + fields = [p for p in parameters if isinstance(p, Field)] + iens = 1 + assume(not storage_ensemble.realizations_initialized([iens])) + for f in fields: + with patch( + "ert.storage.local_storage.NamedTemporaryFile", + RaisingWriteNamedTemporaryFile, + ) as temp_file, pytest.raises(RuntimeError): + storage_ensemble.save_parameters( + f.name, + iens, + xr.DataArray( + field_data, + name="values", + dims=["x", "y", "z"], # type: ignore + ).to_dataset(), + ) + + assert temp_file.entered + assert not storage_ensemble.realizations_initialized([iens]) + @rule( model_ensemble=ensembles, ) @@ -783,6 +865,32 @@ def set_failure(self, model_ensemble: Ensemble, data: st.DataObject, message: st ) model_ensemble.failure_messages[realization] = message + @rule(model_ensemble=ensembles, data=st.data(), message=st.text()) + def write_error_in_set_failure( + self, + model_ensemble: Ensemble, + data: st.DataObject, + message: str, + ): + storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid) + realization = data.draw( + st.integers(min_value=0, max_value=storage_ensemble.ensemble_size - 1) + ) + assume(not storage_ensemble.has_failure(realization)) + + storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid) + + with patch( + "ert.storage.local_storage.NamedTemporaryFile", + RaisingWriteNamedTemporaryFile, + ) as f, pytest.raises(RuntimeError): + storage_ensemble.set_failure( + realization, RealizationStorageState.PARENT_FAILURE, message + ) + assert f.entered + + assert not storage_ensemble.has_failure(realization) + @rule(model_ensemble=ensembles, data=st.data()) def get_failure(self, model_ensemble: Ensemble, data: st.DataObject): storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)