Skip to content

Commit

Permalink
Escape slashes in plotapi
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed Oct 4, 2024
1 parent e546621 commit 5bcc770
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 15 deletions.
4 changes: 4 additions & 0 deletions src/ert/dark_storage/endpoints/records.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
from typing import Any, Dict, List, Mapping, Union
from urllib.parse import unquote
from uuid import UUID, uuid4

import numpy as np
Expand Down Expand Up @@ -33,6 +34,7 @@ async def get_record_observations(
ensemble_id: UUID,
response_name: str,
) -> List[js.ObservationOut]:
response_name = unquote(response_name)
ensemble = storage.get_ensemble(ensemble_id)
obs_keys = get_observation_keys_for_response(ensemble, response_name)
obss = get_observations_for_obs_keys(ensemble, obs_keys)
Expand Down Expand Up @@ -72,6 +74,7 @@ async def get_ensemble_record(
ensemble_id: UUID,
accept: Annotated[Union[str, None], Header()] = None,
) -> Any:
name = unquote(name)
dataframe = data_for_key(storage.get_ensemble(ensemble_id), name)
media_type = accept if accept is not None else "text/csv"
if media_type == "application/x-parquet":
Expand Down Expand Up @@ -143,6 +146,7 @@ def get_ensemble_responses(
def get_std_dev(
*, storage: Storage = DEFAULT_STORAGE, ensemble_id: UUID, key: str, z: int
) -> Response:
key = unquote(key)
ensemble = storage.get_ensemble(ensemble_id)
try:
da = ensemble.calculate_std_dev_for_parameter(key)["values"]
Expand Down
11 changes: 8 additions & 3 deletions src/ert/gui/tools/plot/plot_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from itertools import combinations as combi
from json.decoder import JSONDecodeError
from typing import Any, Dict, List, NamedTuple, Optional
from urllib.parse import quote

import httpx
import numpy as np
Expand Down Expand Up @@ -37,6 +38,10 @@ class EnsembleObject:
)


def escape(s):

Check failure on line 41 in src/ert/gui/tools/plot/plot_api.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Function is missing a type annotation
return quote(quote(quote(s, safe="")))


class PlotApi:
def __init__(self) -> None:
self._all_ensembles: Optional[List[EnsembleObject]] = None
Expand Down Expand Up @@ -163,7 +168,7 @@ def data_for_key(self, ensemble_id: str, key: str) -> pd.DataFrame:

with StorageService.session() as client:
response = client.get(
f"/ensembles/{ensemble.id}/records/{key}",
f"/ensembles/{ensemble.id}/records/{escape(key)}",

Check failure on line 171 in src/ert/gui/tools/plot/plot_api.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Call to untyped function "escape" in typed context
headers={"accept": "application/x-parquet"},
timeout=self._timeout,
)
Expand Down Expand Up @@ -196,7 +201,7 @@ def observations_for_key(self, ensemble_ids: List[str], key: str) -> pd.DataFram

with StorageService.session() as client:
response = client.get(
f"/ensembles/{ensemble.id}/records/{key}/observations",
f"/ensembles/{ensemble.id}/records/{escape(key)}/observations",

Check failure on line 204 in src/ert/gui/tools/plot/plot_api.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Call to untyped function "escape" in typed context
timeout=self._timeout,
)
self._check_response(response)
Expand Down Expand Up @@ -261,7 +266,7 @@ def std_dev_for_parameter(

with StorageService.session() as client:
response = client.get(
f"/ensembles/{ensemble.id}/records/{key}/std_dev",
f"/ensembles/{ensemble.id}/records/{escape(key)}/std_dev",

Check failure on line 269 in src/ert/gui/tools/plot/plot_api.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Call to untyped function "escape" in typed context
params={"z": z},
timeout=self._timeout,
)
Expand Down
18 changes: 11 additions & 7 deletions tests/ert/unit_tests/dark_storage/test_dark_storage_state.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import io
import os
from urllib.parse import quote
from uuid import UUID

import hypothesis.strategies as st
import pandas as pd
import pytest
from hypothesis import assume, settings
from hypothesis import assume
from hypothesis.stateful import rule
from starlette.testclient import TestClient

Expand All @@ -14,7 +15,10 @@
from tests.ert.unit_tests.storage.test_local_storage import StatefulStorageTest


@settings(max_examples=1000)
def escape(s):
return quote(quote(quote(s, safe="")))


class DarkStorageStateTest(StatefulStorageTest):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -62,7 +66,6 @@ def get_responses_through_client(self, model_ensemble):
@rule(model_ensemble=StatefulStorageTest.ensembles, data=st.data())
def get_response_csv_through_client(self, model_ensemble, data):
assume(model_ensemble.response_values)
print("Hit it!")
response_key, response_name = data.draw(
st.sampled_from(
[
Expand All @@ -75,16 +78,17 @@ def get_response_csv_through_client(self, model_ensemble, data):
df = pd.read_parquet(
io.BytesIO(
self.client.get(
f"/ensembles/{model_ensemble.uuid}/records/{response_name}",
f"/ensembles/{model_ensemble.uuid}/records/{escape(response_name)}",
headers={"accept": "application/x-parquet"},
).content
)
)
assert set(df.columns) == set(
model_ensemble.response_values[response_key]
assert {dt[:10] for dt in df.columns} == {
str(dt)[:10]
for dt in model_ensemble.response_values[response_key]
.sel(name=response_name)["time"]
.values
)
}

def teardown(self):
super().teardown()
Expand Down
6 changes: 3 additions & 3 deletions tests/ert/unit_tests/gui/tools/plot/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,9 @@ def mocked_requests_get(*args, **kwargs):

records = {
"/ensembles/ens_id_3/records/FOPR": summary_parquet_data,
"/ensembles/ens_id_3/records/BPR:1,3,8": summary_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_PARAM:BPR_138_PERSISTENCE": parameter_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_PARAM:OP1_DIVERGENCE_SCALE": parameter_parquet_data,
"/ensembles/ens_id_3/records/BPR%25253A1%25252C3%25252C8": summary_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_PARAM%25253ABPR_138_PERSISTENCE": parameter_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_PARAM%25253AOP1_DIVERGENCE_SCALE": parameter_parquet_data,
"/ensembles/ens_id_3/records/SNAKE_OIL_WPR_DIFF@199": gen_parquet_data,
"/ensembles/ens_id_3/records/FOPRH": history_parquet_data,
}
Expand Down
52 changes: 50 additions & 2 deletions tests/ert/unit_tests/gui/tools/plot/test_plot_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
from datetime import datetime

import httpx
import numpy as np
import pandas as pd
import pytest
import xarray as xr
from pandas.testing import assert_frame_equal

from ert.gui.tools.plot.plot_api import PlotApiKeyDefinition
from starlette.testclient import TestClient

from ert.config import SummaryConfig
from ert.dark_storage.app import app
from ert.dark_storage.enkf import update_storage
from ert.gui.tools.plot.plot_api import PlotApi, PlotApiKeyDefinition
from ert.services import StorageService
from ert.storage import open_storage
from tests.ert.unit_tests.gui.tools.plot.conftest import MockResponse


Expand Down Expand Up @@ -146,3 +156,41 @@ def test_plot_api_request_errors(api):

with pytest.raises(httpx.RequestError):
api.data_for_key(ensemble.id, "should_not_be_there")


def test_plot_api_handles_urlescape(tmp_path, monkeypatch):
with open_storage(tmp_path / "storage", mode="w") as storage:
monkeypatch.setenv("ERT_STORAGE_NO_TOKEN", "yup")
monkeypatch.setenv("ERT_STORAGE_ENS_PATH", storage.path)
update_storage()
client = TestClient(app)
monkeypatch.setattr(StorageService, "session", lambda: client)
api = PlotApi()
key = "WBHP:46/3-7S"
experiment = storage.create_experiment(
parameters=[],
responses=[
SummaryConfig(
name="summary",
input_files=["CASE.UNSMRY", "CASE.SMSPEC"],
keys=[key],
)
],
)
ensemble = experiment.create_ensemble(ensemble_size=1, name="ensemble")
assert api.data_for_key(str(ensemble.id), "WBHP:46/3-7S").empty
ensemble.save_response(
"summary",
xr.Dataset(
{"values": (["name", "time"], np.array([[1.0]]))},
coords={"time": [datetime(year=2024, month=10, day=4)], "name": [key]},
),
0,
)
expected = pd.DataFrame({"2024-10-04": [1.0], "Realization": [0]})
expected.set_index("Realization", inplace=True)

assert (
api.data_for_key(str(ensemble.id), "WBHP:46/3-7S").to_csv()
== expected.to_csv()
)

0 comments on commit 5bcc770

Please sign in to comment.