Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure keys can contain "/" #8818

Merged
merged 5 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ert/dark_storage/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def data_for_key(
"summary", tuple(ensemble.get_realization_list_with_responses("summary"))
)
summary_keys = summary_data["response_key"].unique().to_list()
except (ValueError, KeyError):
except (ValueError, KeyError, polars.exceptions.ColumnNotFoundError):
summary_data = polars.DataFrame()
summary_keys = []

Expand Down
4 changes: 4 additions & 0 deletions src/ert/dark_storage/endpoints/records.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
from typing import Any, Dict, List, Mapping, Union
from urllib.parse import unquote
from uuid import UUID, uuid4

import numpy as np
Expand Down Expand Up @@ -34,6 +35,7 @@ async def get_record_observations(
ensemble_id: UUID,
response_name: str,
) -> List[js.ObservationOut]:
response_name = unquote(response_name)
ensemble = storage.get_ensemble(ensemble_id)
obs_keys = get_observation_keys_for_response(ensemble, response_name)
obss = get_observations_for_obs_keys(ensemble, obs_keys)
Expand Down Expand Up @@ -74,6 +76,7 @@ async def get_ensemble_record(
ensemble_id: UUID,
accept: Annotated[Union[str, None], Header()] = None,
) -> Any:
name = unquote(name)
dataframe = data_for_key(storage.get_ensemble(ensemble_id), name)
media_type = accept if accept is not None else "text/csv"
if media_type == "application/x-parquet":
Expand Down Expand Up @@ -153,6 +156,7 @@ def get_ensemble_responses(
def get_std_dev(
*, storage: Storage = DEFAULT_STORAGE, ensemble_id: UUID, key: str, z: int
) -> Response:
key = unquote(key)
ensemble = storage.get_ensemble(ensemble_id)
try:
da = ensemble.calculate_std_dev_for_parameter(key)["values"]
Expand Down
11 changes: 8 additions & 3 deletions src/ert/gui/tools/plot/plot_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from itertools import combinations as combi
from json.decoder import JSONDecodeError
from typing import Any, Dict, List, NamedTuple, Optional
from urllib.parse import quote

import httpx
import numpy as np
Expand Down Expand Up @@ -42,6 +43,10 @@ def __init__(self) -> None:
self._all_ensembles: Optional[List[EnsembleObject]] = None
self._timeout = 120

@staticmethod
def escape(s: str) -> str:
    """Percent-encode ``s`` twice so it can be embedded in a URL path.

    The first pass keeps no reserved character safe, so "/" is encoded as
    well. The second pass re-encodes the "%" signs produced by the first,
    which means the value is still percent-encoded after being decoded once.
    """
    encoded_once = quote(s, safe="")
    return quote(encoded_once)

def _get_ensemble_by_id(self, id: str) -> Optional[EnsembleObject]:
for ensemble in self.get_all_ensembles():
if ensemble.id == id:
Expand Down Expand Up @@ -163,7 +168,7 @@ def data_for_key(self, ensemble_id: str, key: str) -> pd.DataFrame:

with StorageService.session() as client:
response = client.get(
f"/ensembles/{ensemble.id}/records/{key}",
f"/ensembles/{ensemble.id}/records/{PlotApi.escape(key)}",
headers={"accept": "application/x-parquet"},
timeout=self._timeout,
)
Expand Down Expand Up @@ -196,7 +201,7 @@ def observations_for_key(self, ensemble_ids: List[str], key: str) -> pd.DataFram

with StorageService.session() as client:
response = client.get(
f"/ensembles/{ensemble.id}/records/{key}/observations",
f"/ensembles/{ensemble.id}/records/{PlotApi.escape(key)}/observations",
timeout=self._timeout,
)
self._check_response(response)
Expand Down Expand Up @@ -261,7 +266,7 @@ def std_dev_for_parameter(

with StorageService.session() as client:
response = client.get(
f"/ensembles/{ensemble.id}/records/{key}/std_dev",
f"/ensembles/{ensemble.id}/records/{PlotApi.escape(key)}/std_dev",
params={"z": z},
timeout=self._timeout,
)
Expand Down
2 changes: 1 addition & 1 deletion src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def get_summary_keyset(self) -> List[str]:
)

return sorted(summary_data["response_key"].unique().to_list())
except (ValueError, KeyError):
except (ValueError, KeyError, polars.ColumnNotFoundError):
return []

def _load_single_dataset(
Expand Down
114 changes: 114 additions & 0 deletions tests/ert/unit_tests/dark_storage/test_dark_storage_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import gc
import io
import os
from urllib.parse import quote
from uuid import UUID

import hypothesis.strategies as st
import pandas as pd
import polars
import pytest
from hypothesis import assume
from hypothesis.stateful import rule
from starlette.testclient import TestClient

from ert.dark_storage import enkf
from ert.dark_storage.app import app
from tests.ert.unit_tests.storage.test_local_storage import StatefulStorageTest


def escape(s):
    """Percent-encode ``s`` three times.

    Only the first pass treats "/" as unsafe; the two following passes
    re-encode the "%" signs introduced by the previous pass.
    """
    encoded = quote(s, safe="")
    for _ in range(2):
        encoded = quote(encoded)
    return encoded


class DarkStorageStateTest(StatefulStorageTest):
def __init__(self):
    """Set up the storage environment variables and an in-process client.

    The current values of ``ERT_STORAGE_NO_TOKEN`` and
    ``ERT_STORAGE_ENS_PATH`` are remembered on the instance before both are
    overwritten; the latter is set to the path of this test's storage.
    """
    super().__init__()
    environ = os.environ
    self.prev_no_token = environ.get("ERT_STORAGE_NO_TOKEN")
    self.prev_ens_path = environ.get("ERT_STORAGE_ENS_PATH")
    environ.update(
        {
            "ERT_STORAGE_NO_TOKEN": "yup",
            "ERT_STORAGE_ENS_PATH": str(self.storage.path),
        }
    )
    self.client = TestClient(app)

@rule()
def get_experiments_through_client(self):
self.client.get("/updates/storage")
response = self.client.get("/experiments")
experiment_records = response.json()
assert len(experiment_records) == len(list(self.storage.experiments))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we check that this is not zero also?

Copy link
Contributor Author

@eivindjahren eivindjahren Oct 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be empty if we haven't initialized anything in storage yet.

for record in experiment_records:
storage_experiment = self.storage.get_experiment(UUID(record["id"]))
assert {UUID(i) for i in record["ensemble_ids"]} == {
ens.id for ens in storage_experiment.ensembles
}

@rule(model_experiment=StatefulStorageTest.experiments)
def get_observations_through_client(self, model_experiment):
    """Check that the served observation names match the model experiment.

    The names returned by the observations endpoint must equal the set of
    all ``observation_key`` values across the model's observation datasets.
    """
    url = f"/experiments/{model_experiment.uuid}/observations"
    served_names = {entry["name"] for entry in self.client.get(url).json()}

    expected_names = set()
    for dataset in model_experiment.observations.values():
        expected_names.update(dataset["observation_key"])

    assert served_names == expected_names

@rule(model_experiment=StatefulStorageTest.experiments)
def get_ensembles_through_client(self, model_experiment):
    """Check that the served ensemble ids match the model experiment's."""
    url = f"/experiments/{model_experiment.uuid}/ensembles"
    served_ids = {entry["id"] for entry in self.client.get(url).json()}
    expected_ids = set(map(str, model_experiment.ensembles))
    assert served_ids == expected_ids

@rule(model_ensemble=StatefulStorageTest.ensembles)
def get_responses_through_client(self, model_ensemble):
    """Check that the served response keys match the model ensemble.

    The keys of the JSON object returned by the responses endpoint must
    equal every ``response_key`` value found in the model's response data.
    """
    payload = self.client.get(f"/ensembles/{model_ensemble.uuid}/responses")

    expected_keys = set()
    for frame in model_ensemble.response_values.values():
        expected_keys.update(frame["response_key"])

    assert set(payload.json().keys()) == expected_keys

@rule(model_ensemble=StatefulStorageTest.ensembles, data=st.data())
def get_response_csv_through_client(self, model_ensemble, data):
    """Fetch one response key as parquet and compare its dates to the model.

    A (response name, response key) pair is drawn from the model ensemble and
    the record is requested with the key escaped for the URL. The first ten
    characters of each returned column label must match the first ten
    characters of each modelled ``time`` value for that key.
    """
    assume(model_ensemble.response_values)

    candidates = []
    for name, frame in model_ensemble.response_values.items():
        for key in frame["response_key"]:
            candidates.append((name, key))
    response_name, response_key = data.draw(st.sampled_from(candidates))

    reply = self.client.get(
        f"/ensembles/{model_ensemble.uuid}/records/{escape(response_key)}",
        headers={"accept": "application/x-parquet"},
    )
    df = pd.read_parquet(io.BytesIO(reply.content))
    served_dates = {label[:10] for label in df.columns}

    matching_rows = model_ensemble.response_values[response_name].filter(
        polars.col("response_key") == response_key
    )
    expected_dates = {str(moment)[:10] for moment in matching_rows["time"]}

    assert served_dates == expected_dates

def teardown(self):
    """Close and clear ``enkf._storage`` and restore the environment.

    ``ERT_STORAGE_NO_TOKEN`` and ``ERT_STORAGE_ENS_PATH`` are reset to the
    values remembered in ``__init__``; a variable that was unset before the
    test is removed again.
    """
    super().teardown()
    if enkf._storage is not None:
        enkf._storage.close()
    enkf._storage = None
    # NOTE(review): presumably forces lingering storage references to be
    # released before the next run — confirm.
    gc.collect()

    saved = (
        ("ERT_STORAGE_NO_TOKEN", self.prev_no_token),
        ("ERT_STORAGE_ENS_PATH", self.prev_ens_path),
    )
    for name, previous in saved:
        if previous is not None:
            os.environ[name] = previous
        else:
            # pop() instead of del: if the variable has already been removed
            # during the test, teardown must not raise KeyError and hide the
            # original failure.
            os.environ.pop(name, None)


# Expose the state machine's ``TestCase`` under a ``Test*`` name, tagged with
# the ``integration_test`` marker.
TestDarkStorage = pytest.mark.integration_test(DarkStorageStateTest.TestCase)
6 changes: 3 additions & 3 deletions tests/ert/unit_tests/gui/tools/plot/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,9 @@ def mocked_requests_get(*args, **kwargs):

records = {
"/ensembles/ens_id_3/records/FOPR": summary_parquet_data,
"/ensembles/ens_id_3/records/BPR:1,3,8": summary_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_PARAM:BPR_138_PERSISTENCE": parameter_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_PARAM:OP1_DIVERGENCE_SCALE": parameter_parquet_data,
"/ensembles/ens_id_3/records/BPR%25253A1%25252C3%25252C8": summary_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_PARAM%25253ABPR_138_PERSISTENCE": parameter_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_PARAM%25253AOP1_DIVERGENCE_SCALE": parameter_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_WPR_DIFF@199": gen_parquet_data,
"/ensembles/ens_id_3/records/FOPRH": history_parquet_data,
}
Expand Down
139 changes: 138 additions & 1 deletion tests/ert/unit_tests/gui/tools/plot/test_plot_api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,40 @@
import gc
from datetime import datetime
from textwrap import dedent
from urllib.parse import quote

import httpx
import pandas as pd
import polars
import pytest
import xarray as xr
from pandas.testing import assert_frame_equal
from starlette.testclient import TestClient

from ert.gui.tools.plot.plot_api import PlotApiKeyDefinition
from ert.config import GenKwConfig, SummaryConfig
from ert.dark_storage import enkf
from ert.dark_storage.app import app
from ert.gui.tools.plot.plot_api import PlotApi, PlotApiKeyDefinition
from ert.services import StorageService
from ert.storage import open_storage
from tests.ert.unit_tests.gui.tools.plot.conftest import MockResponse


@pytest.fixture(autouse=True)
def use_testclient(monkeypatch):
    """Route ``PlotApi`` requests through an in-process ``TestClient``.

    ``StorageService.session`` is replaced by a factory returning a
    ``TestClient`` for the app, and ``PlotApi.escape`` is replaced by a
    variant that applies one more level of percent-encoding. Both patches go
    through ``monkeypatch`` so they are undone when the test finishes.
    """
    client = TestClient(app)
    monkeypatch.setattr(StorageService, "session", lambda: client)

    def test_escape(s: str) -> str:
        """
        Workaround for issue with TestClient:
        https://github.com/encode/starlette/issues/1060
        """
        return quote(quote(quote(s, safe="")))

    # Patch via monkeypatch rather than plain assignment: a bare
    # ``PlotApi.escape = ...`` is never reverted and would leak the
    # triple-escaping into every test that runs afterwards.
    monkeypatch.setattr(PlotApi, "escape", staticmethod(test_escape))


def test_key_def_structure(api):
key_defs = api.all_data_type_keys()
fopr = next(x for x in key_defs if x.key == "FOPR")
Expand Down Expand Up @@ -146,3 +174,112 @@ def test_plot_api_request_errors(api):

with pytest.raises(httpx.RequestError):
api.data_for_key(ensemble.id, "should_not_be_there")


def test_plot_api_handles_urlescape(tmp_path, monkeypatch):
    """Data and observations can be fetched for a key containing "/" and ":".

    A summary response and a matching observation are stored under the key
    ``WBHP:46/3-7S``; ``PlotApi`` must return an empty frame before any
    response is saved and the stored value / observation afterwards.
    """
    try:
        with open_storage(tmp_path / "storage", mode="w") as storage:
            monkeypatch.setenv("ERT_STORAGE_NO_TOKEN", "yup")
            monkeypatch.setenv("ERT_STORAGE_ENS_PATH", storage.path)
            api = PlotApi()
            key = "WBHP:46/3-7S"
            date = datetime(year=2024, month=10, day=4)
            experiment = storage.create_experiment(
                parameters=[],
                responses=[
                    SummaryConfig(
                        name="summary",
                        input_files=["CASE.UNSMRY", "CASE.SMSPEC"],
                        keys=[key],
                    )
                ],
                observations={
                    "summary": polars.DataFrame(
                        {
                            "response_key": key,
                            "observation_key": "sumobs",
                            "time": polars.Series([date]).dt.cast_time_unit("ms"),
                            "observations": polars.Series(
                                [1.0], dtype=polars.Float32
                            ),
                            "std": polars.Series([1.0], dtype=polars.Float32),
                        }
                    )
                },
            )
            ensemble = experiment.create_ensemble(ensemble_size=1, name="ensemble")
            # Nothing has been saved yet, so the key resolves to no data.
            assert api.data_for_key(str(ensemble.id), key).empty
            df = polars.DataFrame(
                {
                    "response_key": [key],
                    "time": [polars.Series([date]).dt.cast_time_unit("ms")],
                    "values": [polars.Series([1.0], dtype=polars.Float32)],
                }
            )
            df = df.explode("values", "time")
            ensemble.save_response(
                "summary",
                df,
                0,
            )
            assert api.data_for_key(str(ensemble.id), key).to_csv() == dedent(
                """\
                Realization,2024-10-04
                0,1.0
                """
            )
            assert api.observations_for_key(
                [str(ensemble.id)], key
            ).to_csv() == dedent(
                """\
                ,0
                STD,1.0
                OBS,1.0
                key_index,2024-10-04 00:00:00
                """
            )
    finally:
        # Run the cleanup even when an assertion above fails; otherwise the
        # module-level enkf._storage stays open for the tests that follow.
        if enkf._storage is not None:
            enkf._storage.close()
        enkf._storage = None
        gc.collect()


def test_plot_api_handles_empty_gen_kw(tmp_path, monkeypatch):
    """A GEN_KW group with no saved data yields an empty frame.

    After a single value named ``<poro>`` is saved for realization 1, the
    key ``gen_kw:<poro>`` must return that value.
    """
    try:
        with open_storage(tmp_path / "storage", mode="w") as storage:
            monkeypatch.setenv("ERT_STORAGE_NO_TOKEN", "yup")
            monkeypatch.setenv("ERT_STORAGE_ENS_PATH", storage.path)
            api = PlotApi()
            key = "gen_kw"
            name = "<poro>"
            experiment = storage.create_experiment(
                parameters=[
                    GenKwConfig(
                        name=key,
                        forward_init=False,
                        update=False,
                        template_file=None,
                        output_file=None,
                        transform_function_definitions=[],
                    ),
                ],
                responses=[],
                observations={},
            )
            ensemble = storage.create_ensemble(experiment.id, ensemble_size=10)
            # No parameters have been saved yet.
            assert api.data_for_key(str(ensemble.id), key).empty
            ensemble.save_parameters(
                key,
                1,
                xr.Dataset(
                    {
                        "values": ("names", [1.0]),
                        "transformed_values": ("names", [1.0]),
                        "names": [name],
                    }
                ),
            )
            assert api.data_for_key(
                str(ensemble.id), key + ":" + name
            ).to_csv() == dedent(
                """\
                Realization,0
                1,1.0
                """
            )
    finally:
        # Run the cleanup even when an assertion above fails; otherwise the
        # module-level enkf._storage stays open for the tests that follow.
        if enkf._storage is not None:
            enkf._storage.close()
        enkf._storage = None
        gc.collect()
Loading