diff --git a/src/ert/dark_storage/endpoints/records.py b/src/ert/dark_storage/endpoints/records.py index 8c1ca4aadfd..0083a3cc3c0 100644 --- a/src/ert/dark_storage/endpoints/records.py +++ b/src/ert/dark_storage/endpoints/records.py @@ -1,5 +1,6 @@ import io from typing import Any, Dict, List, Mapping, Union +from urllib.parse import unquote from uuid import UUID, uuid4 import numpy as np @@ -33,6 +34,7 @@ async def get_record_observations( ensemble_id: UUID, response_name: str, ) -> List[js.ObservationOut]: + response_name = unquote(response_name) ensemble = storage.get_ensemble(ensemble_id) obs_keys = get_observation_keys_for_response(ensemble, response_name) obss = get_observations_for_obs_keys(ensemble, obs_keys) @@ -72,6 +74,7 @@ async def get_ensemble_record( ensemble_id: UUID, accept: Annotated[Union[str, None], Header()] = None, ) -> Any: + name = unquote(name) dataframe = data_for_key(storage.get_ensemble(ensemble_id), name) media_type = accept if accept is not None else "text/csv" if media_type == "application/x-parquet": @@ -143,6 +146,7 @@ def get_ensemble_responses( def get_std_dev( *, storage: Storage = DEFAULT_STORAGE, ensemble_id: UUID, key: str, z: int ) -> Response: + key = unquote(key) ensemble = storage.get_ensemble(ensemble_id) try: da = ensemble.calculate_std_dev_for_parameter(key)["values"] diff --git a/src/ert/gui/tools/plot/plot_api.py b/src/ert/gui/tools/plot/plot_api.py index faf503b80a8..8742e240a7e 100644 --- a/src/ert/gui/tools/plot/plot_api.py +++ b/src/ert/gui/tools/plot/plot_api.py @@ -4,6 +4,7 @@ from itertools import combinations as combi from json.decoder import JSONDecodeError from typing import Any, Dict, List, NamedTuple, Optional +from urllib.parse import quote import httpx import numpy as np @@ -37,6 +38,10 @@ class EnsembleObject: ) +def escape(s): + return quote(quote(quote(s, safe=""))) + + class PlotApi: def __init__(self) -> None: self._all_ensembles: Optional[List[EnsembleObject]] = None @@ -163,7 +168,7 @@ def data_for_key(self, ensemble_id: str, key: str) -> pd.DataFrame: with StorageService.session() as client: response = client.get( - f"/ensembles/{ensemble.id}/records/{key}", + f"/ensembles/{ensemble.id}/records/{escape(key)}", headers={"accept": "application/x-parquet"}, timeout=self._timeout, ) @@ -196,7 +201,7 @@ def observations_for_key(self, ensemble_ids: List[str], key: str) -> pd.DataFram with StorageService.session() as client: response = client.get( - f"/ensembles/{ensemble.id}/records/{key}/observations", + f"/ensembles/{ensemble.id}/records/{escape(key)}/observations", timeout=self._timeout, ) self._check_response(response) @@ -261,7 +266,7 @@ def std_dev_for_parameter( with StorageService.session() as client: response = client.get( - f"/ensembles/{ensemble.id}/records/{key}/std_dev", + f"/ensembles/{ensemble.id}/records/{escape(key)}/std_dev", params={"z": z}, timeout=self._timeout, ) diff --git a/tests/ert/unit_tests/dark_storage/test_dark_storage_state.py b/tests/ert/unit_tests/dark_storage/test_dark_storage_state.py index a81905f37e0..e0f453c1322 100644 --- a/tests/ert/unit_tests/dark_storage/test_dark_storage_state.py +++ b/tests/ert/unit_tests/dark_storage/test_dark_storage_state.py @@ -1,11 +1,12 @@ import io import os +from urllib.parse import quote from uuid import UUID import hypothesis.strategies as st import pandas as pd import pytest -from hypothesis import assume, settings +from hypothesis import assume from hypothesis.stateful import rule from starlette.testclient import TestClient @@ -14,7 +15,10 @@ from tests.ert.unit_tests.storage.test_local_storage import StatefulStorageTest -@settings(max_examples=1000) +def escape(s): + return quote(quote(quote(s, safe=""))) + + class DarkStorageStateTest(StatefulStorageTest): def __init__(self): super().__init__() @@ -62,7 +66,6 @@ def get_responses_through_client(self, model_ensemble): @rule(model_ensemble=StatefulStorageTest.ensembles, data=st.data()) def get_response_csv_through_client(self, model_ensemble, data): assume(model_ensemble.response_values) - print("Hit it!") response_key, response_name = data.draw( st.sampled_from( [ @@ -75,16 +78,17 @@ def get_response_csv_through_client(self, model_ensemble, data): df = pd.read_parquet( io.BytesIO( self.client.get( - f"/ensembles/{model_ensemble.uuid}/records/{response_name}", + f"/ensembles/{model_ensemble.uuid}/records/{escape(response_name)}", headers={"accept": "application/x-parquet"}, ).content ) ) - assert set(df.columns) == set( - model_ensemble.response_values[response_key] + assert {dt[:10] for dt in df.columns} == { + str(dt)[:10] + for dt in model_ensemble.response_values[response_key] .sel(name=response_name)["time"] .values - ) + } def teardown(self): super().teardown() diff --git a/tests/ert/unit_tests/gui/tools/plot/conftest.py b/tests/ert/unit_tests/gui/tools/plot/conftest.py index c5eba03ab5a..dc59088248b 100644 --- a/tests/ert/unit_tests/gui/tools/plot/conftest.py +++ b/tests/ert/unit_tests/gui/tools/plot/conftest.py @@ -209,9 +209,9 @@ def mocked_requests_get(*args, **kwargs): records = { "/ensembles/ens_id_3/records/FOPR": summary_parquet_data, - "/ensembles/ens_id_3/records/BPR:1,3,8": summary_parquet_data, - "/ensembles/ens_id_3/records/SNAKE_OIL_PARAM:BPR_138_PERSISTENCE": parameter_parquet_data, - "/ensembles/ens_id_3/records/SNAKE_OIL_PARAM:OP1_DIVERGENCE_SCALE": parameter_parquet_data, + "/ensembles/ens_id_3/records/BPR%25253A1%25252C3%25252C8": summary_parquet_data, + "/ensembles/ens_id_3/records/SNAKE_OIL_PARAM%25253ABPR_138_PERSISTENCE": parameter_parquet_data, + "/ensembles/ens_id_3/records/SNAKE_OIL_PARAM%25253AOP1_DIVERGENCE_SCALE": parameter_parquet_data, "/ensembles/ens_id_3/records/SNAKE_OIL_WPR_DIFF@199": gen_parquet_data, "/ensembles/ens_id_3/records/FOPRH": history_parquet_data, } diff --git a/tests/ert/unit_tests/gui/tools/plot/test_plot_api.py b/tests/ert/unit_tests/gui/tools/plot/test_plot_api.py index 75c031b1a1a..f77c252d19a 100644 --- a/tests/ert/unit_tests/gui/tools/plot/test_plot_api.py +++ b/tests/ert/unit_tests/gui/tools/plot/test_plot_api.py @@ -1,9 +1,19 @@ +from datetime import datetime + import httpx +import numpy as np import pandas as pd import pytest +import xarray as xr from pandas.testing import assert_frame_equal - -from ert.gui.tools.plot.plot_api import PlotApiKeyDefinition +from starlette.testclient import TestClient + +from ert.config import SummaryConfig +from ert.dark_storage.app import app +from ert.dark_storage.enkf import update_storage +from ert.gui.tools.plot.plot_api import PlotApi, PlotApiKeyDefinition +from ert.services import StorageService +from ert.storage import open_storage from tests.ert.unit_tests.gui.tools.plot.conftest import MockResponse @@ -146,3 +156,41 @@ def test_plot_api_request_errors(api): with pytest.raises(httpx.RequestError): api.data_for_key(ensemble.id, "should_not_be_there") + + +def test_plot_api_handles_urlescape(tmp_path, monkeypatch): + with open_storage(tmp_path / "storage", mode="w") as storage: + monkeypatch.setenv("ERT_STORAGE_NO_TOKEN", "yup") + monkeypatch.setenv("ERT_STORAGE_ENS_PATH", storage.path) + update_storage() + client = TestClient(app) + monkeypatch.setattr(StorageService, "session", lambda: client) + api = PlotApi() + key = "WBHP:46/3-7S" + experiment = storage.create_experiment( + parameters=[], + responses=[ + SummaryConfig( + name="summary", + input_files=["CASE.UNSMRY", "CASE.SMSPEC"], + keys=[key], + ) + ], + ) + ensemble = experiment.create_ensemble(ensemble_size=1, name="ensemble") + assert api.data_for_key(str(ensemble.id), "WBHP:46/3-7S").empty + ensemble.save_response( + "summary", + xr.Dataset( + {"values": (["name", "time"], np.array([[1.0]]))}, + coords={"time": [datetime(year=2024, month=10, day=4)], "name": [key]}, + ), + 0, + ) + expected = pd.DataFrame({"2024-10-04": [1.0], "Realization": [0]}) + expected.set_index("Realization", inplace=True) + + assert ( + api.data_for_key(str(ensemble.id), "WBHP:46/3-7S").to_csv() + == expected.to_csv() + )