Skip to content

Commit

Permalink
Make sure that migrations use correct dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
yngve-sk committed Nov 29, 2024
1 parent 8a5567d commit 82ab54a
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 27 deletions.
70 changes: 43 additions & 27 deletions src/ert/storage/migration/to8.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,35 @@ def from_path(cls, path: Path) -> "ObservationDatasetInfo":
response_key = ds.attrs["response"]
response_type = "summary" if response_key == "summary" else "gen_data"

df = polars.from_pandas(ds.to_dataframe().dropna().reset_index())
df = polars.from_pandas(
ds.to_dataframe().dropna().reset_index(),
schema_overrides={
"report_step": polars.UInt16,
"index": polars.UInt16,
"observations": polars.Float32,
"std": polars.Float32,
}
if response_type == "gen_data"
else {
"time": polars.Datetime("ms"), # type: ignore
"observations": polars.Float32,
"std": polars.Float32,
},
)

df = df.with_columns(observation_key=polars.lit(observation_key))

primary_key = (
["time"] if response_type == "summary" else ["report_step", "index"]
)
if response_type == "summary":
df = df.rename({"name": "response_key"})
df = df.with_columns(polars.col("time").dt.cast_time_unit("ms"))

if response_type == "gen_data":
df = df.with_columns(
polars.col("report_step").cast(polars.UInt16),
polars.col("index").cast(polars.UInt16),
response_key=polars.lit(response_key),
)

df = df.with_columns(
[
polars.col("std").cast(polars.Float32),
polars.col("observations").cast(polars.Float32),
]
)

df = df[
["response_key", "observation_key", *primary_key, "observations", "std"]
]
Expand All @@ -71,27 +76,38 @@ def _migrate_responses_from_netcdf_to_parquet(path: Path) -> None:
real_dirs = [*ens.glob("realization-*")]

for real_dir in real_dirs:
for ds_name in ["gen_data", "summary"]:
if (real_dir / f"{ds_name}.nc").exists():
gen_data_ds = xr.open_dataset(
real_dir / f"{ds_name}.nc", engine="scipy"
for response_type, schema_overrides in [
(
"gen_data",
{
"realization": polars.UInt16,
"report_step": polars.UInt16,
"index": polars.UInt16,
"values": polars.Float32,
},
),
(
"summary",
{
"realization": polars.UInt16,
"time": polars.Datetime("ms"),
"values": polars.Float32,
},
),
]:
if (real_dir / f"{response_type}.nc").exists():
xr_ds = xr.open_dataset(
real_dir / f"{response_type}.nc",
engine="scipy",
)

pandas_df = gen_data_ds.to_dataframe().dropna().reset_index()
pandas_df = xr_ds.to_dataframe().dropna().reset_index()
polars_df = polars.from_pandas(
pandas_df,
schema_overrides={
"values": polars.Float32,
"realization": polars.UInt16,
},
schema_overrides=schema_overrides, # type: ignore
)
polars_df = polars_df.rename({"name": "response_key"})

if "time" in polars_df:
polars_df = polars_df.with_columns(
polars.col("time").dt.cast_time_unit("ms")
)

# Ensure "response_key" is the first column
polars_df = polars_df.select(
["response_key"]
Expand All @@ -101,9 +117,9 @@ def _migrate_responses_from_netcdf_to_parquet(path: Path) -> None:
if col != "response_key"
]
)
polars_df.write_parquet(real_dir / f"{ds_name}.parquet")
polars_df.write_parquet(real_dir / f"{response_type}.parquet")

os.remove(real_dir / f"{ds_name}.nc")
os.remove(real_dir / f"{response_type}.nc")


def _migrate_observations_to_grouped_parquet(path: Path) -> None:
Expand Down
96 changes: 96 additions & 0 deletions tests/ert/unit_tests/storage/test_storage_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from pathlib import Path

import numpy as np
import polars
import pytest
from packaging import version

from ert.analysis import ErtAnalysisError, smoother_update
from ert.config import ErtConfig
from ert.storage import open_storage
from ert.storage.local_storage import (
Expand Down Expand Up @@ -355,3 +357,97 @@ def test_that_migrate_blockfs_creates_backup_folder(tmp_path, caplog):
assert (
tmp_path / "storage" / "_blockfs_backup" / "ensembles" / "ens_dummy.txt"
).exists()


@pytest.mark.integration_test
@pytest.mark.usefixtures("copy_shared")
@pytest.mark.parametrize(
"ert_version",
[
"10.3.1",
"8.4.5",
"8.0.11",
"6.0.5",
"5.0.0",
],
)
def test_that_manual_update_from_migrated_storage_works(
tmp_path,
block_storage_path,
snapshot,
monkeypatch,
ert_version,
):
shutil.copytree(
block_storage_path / f"all_data_types/storage-{ert_version}",
tmp_path / "all_data_types" / f"storage-{ert_version}",
)
monkeypatch.chdir(tmp_path / "all_data_types")
ert_config = ErtConfig.with_plugins().from_file("config.ert")
local_storage_set_ert_config(ert_config)
# To make sure all tests run against the same snapshot
snapshot.snapshot_dir = snapshot.snapshot_dir.parent
with open_storage(f"storage-{ert_version}", "w") as storage:
experiments = list(storage.experiments)
assert len(experiments) == 1
experiment = experiments[0]
ensembles = list(experiment.ensembles)
assert len(ensembles) == 1
prior_ens = ensembles[0]

assert set(experiment.observations["gen_data"].schema.items()) == {
("index", polars.UInt16),
("observation_key", polars.String),
("observations", polars.Float32),
("report_step", polars.UInt16),
("response_key", polars.String),
("std", polars.Float32),
}

assert set(experiment.observations["summary"].schema.items()) == {
("observation_key", polars.String),
("observations", polars.Float32),
("response_key", polars.String),
("std", polars.Float32),
("time", polars.Datetime(time_unit="ms")),
}

prior_gendata = prior_ens.load_responses(
"gen_data", tuple(range(prior_ens.ensemble_size))
)
prior_smry = prior_ens.load_responses(
"summary", tuple(range(prior_ens.ensemble_size))
)

assert set(prior_gendata.schema.items()) == {
("response_key", polars.String),
("index", polars.UInt16),
("realization", polars.UInt16),
("report_step", polars.UInt16),
("values", polars.Float32),
}

assert set(prior_smry.schema.items()) == {
("response_key", polars.String),
("time", polars.Datetime(time_unit="ms")),
("realization", polars.UInt16),
("values", polars.Float32),
}

posterior_ens = storage.create_ensemble(
prior_ens.experiment_id,
ensemble_size=prior_ens.ensemble_size,
iteration=1,
name="posterior",
prior_ensemble=prior_ens,
)

with pytest.raises(
ErtAnalysisError, match="No active observations for update step"
):
smoother_update(
prior_ens,
posterior_ens,
list(experiment.observation_keys),
list(ert_config.ensemble_config.parameters),
)

0 comments on commit 82ab54a

Please sign in to comment.