diff --git a/mpcontribs-client/mpcontribs/client/__init__.py b/mpcontribs-client/mpcontribs/client/__init__.py index a71f7929f..7a8bab565 100644 --- a/mpcontribs-client/mpcontribs/client/__init__.py +++ b/mpcontribs-client/mpcontribs/client/__init__.py @@ -59,6 +59,8 @@ from pint.errors import DimensionalityError from tempfile import gettempdir from plotly.express._chart_types import line as line_chart +from cachetools import cached, LRUCache +from cachetools.keys import hashkey RETRIES = 3 MAX_WORKERS = 3 @@ -680,7 +682,9 @@ def _load(protocol, host, headers_json, project, version): projects = sorted(d["name"] for d in resp["data"]) projects_json = ujson.dumps(projects) # expand regex-based query parameters for `data` columns - return _expand_params(protocol, host, version, projects_json) + spec = _expand_params(protocol, host, version, projects_json, apikey=headers["x-api-key"]) + spec.http_client.session.headers.update(headers) + return spec @functools.lru_cache(maxsize=1) @@ -714,15 +718,22 @@ def _raw_specs(protocol, host, version): return spec_dict -@functools.lru_cache(maxsize=100) -def _expand_params(protocol, host, version, projects_json): +@cached( + cache=LRUCache(maxsize=100), + key=lambda protocol, host, version, projects_json, **kwargs: hashkey( + protocol, host, version, projects_json + ) +) +def _expand_params(protocol, host, version, projects_json, apikey=None): columns = {"string": [], "number": []} projects = ujson.loads(projects_json) - query = {"project__in": projects} + query = {"project__in": ",".join(projects)} query["_fields"] = ["columns"] url = f"{protocol}://{host}" http_client = RequestsClient() http_client.session.headers["Content-Type"] = "application/json" + if apikey: + http_client.session.headers["X-Api-Key"] = apikey resp = http_client.session.get(f"{url}/projects/", params=query).json() for proj in resp["data"]: diff --git a/mpcontribs-client/setup.py b/mpcontribs-client/setup.py index bbf50728f..06f6be66a 100644 --- a/mpcontribs-client/setup.py +++ b/mpcontribs-client/setup.py @@ -38,6 +38,7 @@ def local_version(version): "tqdm", "ujson", "semantic-version", + "cachetools", ], extras_require={ "dev": [