Escape slashes in plotapi
eivindjahren committed Oct 11, 2024
1 parent bc8b1fb commit daffeed
Showing 9 changed files with 131 additions and 22 deletions.
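
Why the escaping is needed, in brief: response keys such as "WBHP:46/3-7S" or "BPR:1,3,8" contain characters that are special inside a URL path, so the plot API client now percent-encodes the key twice (the new PlotApi.escape) before building the request path, while the dark-storage endpoints call unquote() once on the received path parameter. A minimal standard-library sketch of that round trip; the assumption that the HTTP routing layer strips the remaining level of encoding is an inference from the diff, not something stated in it:

from urllib.parse import quote, unquote

key = "WBHP:46/3-7S"                  # key used by the new test below
escaped = quote(quote(key, safe=""))  # what PlotApi.escape(key) produces
# escaped == "WBHP%253A46%252F3-7S"
# one decode is assumed to happen in routing, the other in the endpoint's unquote()
assert unquote(unquote(escaped)) == key
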
2 changes: 1 addition & 1 deletion src/ert/dark_storage/common.py
@@ -109,7 +109,7 @@ def data_for_key(
"summary", tuple(ensemble.get_realization_list_with_responses("summary"))
)
summary_keys = summary_data["response_key"].unique().to_list()
except (ValueError, KeyError):
except (ValueError, KeyError, polars.exceptions.ColumnNotFoundError):
summary_data = polars.DataFrame()
summary_keys = []

4 changes: 4 additions & 0 deletions src/ert/dark_storage/endpoints/records.py
@@ -1,5 +1,6 @@
import io
from typing import Any, Dict, List, Mapping, Union
from urllib.parse import unquote
from uuid import UUID, uuid4

import numpy as np
@@ -34,6 +35,7 @@ async def get_record_observations(
ensemble_id: UUID,
response_name: str,
) -> List[js.ObservationOut]:
response_name = unquote(response_name)
ensemble = storage.get_ensemble(ensemble_id)
obs_keys = get_observation_keys_for_response(ensemble, response_name)
obss = get_observations_for_obs_keys(ensemble, obs_keys)
@@ -74,6 +76,7 @@ async def get_ensemble_record(
ensemble_id: UUID,
accept: Annotated[Union[str, None], Header()] = None,
) -> Any:
name = unquote(name)
dataframe = data_for_key(storage.get_ensemble(ensemble_id), name)
media_type = accept if accept is not None else "text/csv"
if media_type == "application/x-parquet":
@@ -153,6 +156,7 @@ def get_ensemble_responses(
def get_std_dev(
*, storage: Storage = DEFAULT_STORAGE, ensemble_id: UUID, key: str, z: int
) -> Response:
key = unquote(key)
ensemble = storage.get_ensemble(ensemble_id)
try:
da = ensemble.calculate_std_dev_for_parameter(key)["values"]
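
Each of the three endpoints above now decodes its path parameter exactly once before using it as a storage key. A small sketch under the assumption that the value reaches the endpoint with one level of percent-encoding still intact (the routing layer having removed the outer level):

from urllib.parse import unquote

response_name = "WBHP%3A46%2F3-7S"  # assumed value as seen by the endpoint
assert unquote(response_name) == "WBHP:46/3-7S"
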
14 changes: 11 additions & 3 deletions src/ert/gui/tools/plot/plot_api.py
@@ -4,6 +4,7 @@
from itertools import combinations as combi
from json.decoder import JSONDecodeError
from typing import Any, Dict, List, NamedTuple, Optional
from urllib.parse import quote

import httpx
import numpy as np
@@ -42,6 +43,10 @@ def __init__(self) -> None:
self._all_ensembles: Optional[List[EnsembleObject]] = None
self._timeout = 120

@staticmethod
def escape(s: str) -> str:
return quote(quote(s, safe=""))

def _get_ensemble_by_id(self, id: str) -> Optional[EnsembleObject]:
for ensemble in self.get_all_ensembles():
if ensemble.id == id:
@@ -162,8 +167,9 @@ def data_for_key(self, ensemble_id: str, key: str) -> pd.DataFrame:
return pd.DataFrame()

with StorageService.session() as client:
response = client.get(
f"/ensembles/{ensemble.id}/records/{key}",
f"/ensembles/{ensemble.id}/records/{PlotApi.escape(key)}",
headers={"accept": "application/x-parquet"},
timeout=self._timeout,
)
@@ -195,8 +201,9 @@ def observations_for_key(self, ensemble_ids: List[str], key: str) -> pd.DataFrame:
continue

with StorageService.session() as client:
response = client.get(
f"/ensembles/{ensemble.id}/records/{key}/observations",
f"/ensembles/{ensemble.id}/records/{PlotApi.escape(key)}/observations",
timeout=self._timeout,
)
self._check_response(response)
@@ -260,8 +267,9 @@ def std_dev_for_parameter(
return np.array([])

with StorageService.session() as client:
response = client.get(
f"/ensembles/{ensemble.id}/records/{key}/std_dev",
f"/ensembles/{ensemble.id}/records/{PlotApi.escape(key)}/std_dev",
params={"z": z},
timeout=self._timeout,
)
2 changes: 1 addition & 1 deletion src/ert/storage/local_ensemble.py
@@ -519,7 +519,7 @@ def get_summary_keyset(self) -> List[str]:
)

return sorted(summary_data["response_key"].unique().to_list())
except (ValueError, KeyError):
except (ValueError, KeyError, polars.ColumnNotFoundError):
return []

def _load_single_dataset(
2 changes: 2 additions & 0 deletions tests/ert/unit_tests/dark_storage/conftest.py
@@ -12,6 +12,7 @@
from ert.cli.main import run_cli
from ert.dark_storage import enkf
from ert.dark_storage.app import app
from ert.dark_storage.enkf import update_storage
from ert.mode_definitions import ENSEMBLE_SMOOTHER_MODE


@@ -52,6 +53,7 @@ def poly_example_tmp_dir(poly_example_tmp_dir_shared):
def dark_storage_client(monkeypatch):
with dark_storage_app_(monkeypatch) as dark_app:
monkeypatch.setenv("ERT_STORAGE_ENS_PATH", "storage")
update_storage()
with TestClient(dark_app) as client:
yield client

30 changes: 19 additions & 11 deletions tests/ert/unit_tests/dark_storage/test_dark_storage_state.py
@@ -1,11 +1,12 @@
import io
import os
from urllib.parse import quote
from uuid import UUID

import hypothesis.strategies as st
import pandas as pd
import pytest
from hypothesis import assume, settings
from hypothesis import assume
from hypothesis.stateful import rule
from starlette.testclient import TestClient

@@ -14,7 +15,10 @@
from tests.ert.unit_tests.storage.test_local_storage import StatefulStorageTest


@settings(max_examples=1000)
def escape(s):
return quote(quote(quote(s, safe="")))


class DarkStorageStateTest(StatefulStorageTest):
def __init__(self):
super().__init__()
@@ -40,9 +44,11 @@ def get_experiments_through_client(self):
@rule(model_experiment=StatefulStorageTest.experiments)
def get_observations_through_client(self, model_experiment):
response = self.client.get(f"/experiments/{model_experiment.uuid}/observations")
assert {r["name"] for r in response.json()} == set(
model_experiment.observations.keys()
)
assert {r["name"] for r in response.json()} == {
key
for _, ds in model_experiment.observations.items()
for key in ds["observation_key"]
}

@rule(model_experiment=StatefulStorageTest.experiments)
def get_ensembles_through_client(self, model_experiment):
@@ -55,14 +61,15 @@ def get_ensembles_through_client(self, model_experiment):
def get_responses_through_client(self, model_ensemble):
response = self.client.get(f"/ensembles/{model_ensemble.uuid}/responses")
response_names = {
k for r in model_ensemble.response_values.values() for k in r["name"].values
k
for r in model_ensemble.response_values.values()
for k in r["response_key"]
}
assert set(response.json().keys()) == response_names

@rule(model_ensemble=StatefulStorageTest.ensembles, data=st.data())
def get_response_csv_through_client(self, model_ensemble, data):
assume(model_ensemble.response_values)
print("Hit it!")
response_key, response_name = data.draw(
st.sampled_from(
[
@@ -75,16 +82,17 @@ def get_response_csv_through_client(self, model_ensemble, data):
df = pd.read_parquet(
io.BytesIO(
self.client.get(
f"/ensembles/{model_ensemble.uuid}/records/{response_name}",
f"/ensembles/{model_ensemble.uuid}/records/{escape(response_name)}",
headers={"accept": "application/x-parquet"},
).content
)
)
assert set(df.columns) == set(
model_ensemble.response_values[response_key]
assert {dt[:10] for dt in df.columns} == {
str(dt)[:10]
for dt in model_ensemble.response_values[response_key]
.sel(name=response_name)["time"]
.values
)
}

def teardown(self):
super().teardown()
6 changes: 3 additions & 3 deletions tests/ert/unit_tests/gui/tools/plot/conftest.py
@@ -209,9 +209,9 @@ def mocked_requests_get(*args, **kwargs):

records = {
"/ensembles/ens_id_3/records/FOPR": summary_parquet_data,
"/ensembles/ens_id_3/records/BPR:1,3,8": summary_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_PARAM:BPR_138_PERSISTENCE": parameter_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_PARAM:OP1_DIVERGENCE_SCALE": parameter_parquet_data,
"/ensembles/ens_id_3/records/BPR%25253A1%25252C3%25252C8": summary_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_PARAM%25253ABPR_138_PERSISTENCE": parameter_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_PARAM%25253AOP1_DIVERGENCE_SCALE": parameter_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_WPR_DIFF@199": gen_parquet_data,
"/ensembles/ens_id_3/records/FOPRH": history_parquet_data,
}
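
The rewritten mock paths above are the triple-percent-encoded form of the original keys, matching the triple-quoting escape helpers added elsewhere in this commit (the escape helper in test_dark_storage_state.py above and the PlotApi.escape override in test_plot_api.py below). A quick standard-library check of that correspondence:

from urllib.parse import quote

assert quote(quote(quote("BPR:1,3,8", safe=""))) == "BPR%25253A1%25252C3%25252C8"
assert (
    quote(quote(quote("SNAKE_OIL_PARAM:BPR_138_PERSISTENCE", safe="")))
    == "SNAKE_OIL_PARAM%25253ABPR_138_PERSISTENCE"
)
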
91 changes: 89 additions & 2 deletions tests/ert/unit_tests/gui/tools/plot/test_plot_api.py
@@ -1,12 +1,39 @@
from datetime import datetime
from textwrap import dedent
from urllib.parse import quote

import httpx
import pandas as pd
import polars
import pytest
from pandas.testing import assert_frame_equal

from ert.gui.tools.plot.plot_api import PlotApiKeyDefinition
from starlette.testclient import TestClient

from ert.config import SummaryConfig
from ert.dark_storage.app import app
from ert.dark_storage.enkf import update_storage
from ert.gui.tools.plot.plot_api import PlotApi, PlotApiKeyDefinition
from ert.services import StorageService
from ert.storage import open_storage
from tests.ert.unit_tests.gui.tools.plot.conftest import MockResponse


@pytest.fixture(autouse=True)
def use_testclient(monkeypatch):
client = TestClient(app)
monkeypatch.setattr(StorageService, "session", lambda: client)

def test_escape(s: str) -> str:
"""
Workaround for issue with TestClient:
https://github.com/encode/starlette/issues/1060
"""
return quote(quote(quote(s, safe="")))

PlotApi.escape = test_escape


def test_key_def_structure(api):
key_defs = api.all_data_type_keys()
fopr = next(x for x in key_defs if x.key == "FOPR")
@@ -146,3 +173,63 @@ def test_plot_api_request_errors(api):

with pytest.raises(httpx.RequestError):
api.data_for_key(ensemble.id, "should_not_be_there")


def test_plot_api_handles_urlescape(tmp_path, monkeypatch):
with open_storage(tmp_path / "storage", mode="w") as storage:
monkeypatch.setenv("ERT_STORAGE_NO_TOKEN", "yup")
monkeypatch.setenv("ERT_STORAGE_ENS_PATH", storage.path)
update_storage()
api = PlotApi()
key = "WBHP:46/3-7S"
date = datetime(year=2024, month=10, day=4)
experiment = storage.create_experiment(
parameters=[],
responses=[
SummaryConfig(
name="summary",
input_files=["CASE.UNSMRY", "CASE.SMSPEC"],
keys=[key],
)
],
observations={
"summary": polars.DataFrame(
{
"response_key": key,
"observation_key": "sumobs",
"time": polars.Series([date]).dt.cast_time_unit("ms"),
"observations": polars.Series([1.0], dtype=polars.Float32),
"std": polars.Series([1.0], dtype=polars.Float32),
}
)
},
)
ensemble = experiment.create_ensemble(ensemble_size=1, name="ensemble")
assert api.data_for_key(str(ensemble.id), key).empty
df = polars.DataFrame(
{
"response_key": [key],
"time": [polars.Series([date]).dt.cast_time_unit("ms")],
"values": [polars.Series([1.0], dtype=polars.Float32)],
}
)
df = df.explode("values", "time")
ensemble.save_response(
"summary",
df,
0,
)
assert api.data_for_key(str(ensemble.id), key).to_csv() == dedent(
"""\
Realization,2024-10-04
0,1.0
"""
)
assert api.observations_for_key([str(ensemble.id)], key).to_csv() == dedent(
"""\
,0
STD,1.0
OBS,1.0
key_index,2024-10-04 00:00:00
"""
)
2 changes: 1 addition & 1 deletion tests/ert/unit_tests/storage/test_local_storage.py
@@ -545,7 +545,7 @@ class Experiment:
ensembles: Dict[UUID, Ensemble] = field(default_factory=dict)
parameters: List[ParameterConfig] = field(default_factory=list)
responses: List[ResponseConfig] = field(default_factory=list)
observations: Dict[str, xr.Dataset] = field(default_factory=dict)
observations: Dict[str, polars.DataFrame] = field(default_factory=dict)


class StatefulStorageTest(RuleBasedStateMachine):
