diff --git a/mpcontribs-client/mpcontribs/client/__init__.py b/mpcontribs-client/mpcontribs/client/__init__.py index 664cbf4b0..a71f7929f 100644 --- a/mpcontribs-client/mpcontribs/client/__init__.py +++ b/mpcontribs-client/mpcontribs/client/__init__.py @@ -7,6 +7,7 @@ import gzip import warnings import pandas as pd +import numpy as np import plotly.io as pio import itertools import functools @@ -18,7 +19,6 @@ from math import isclose from semantic_version import Version from requests.exceptions import RequestException -from bravado_core.param import Param from bson.objectid import ObjectId from typing import Union, Type, List from tqdm.auto import tqdm @@ -39,7 +39,9 @@ from bravado.requests_client import RequestsClient from bravado.swagger_model import Loader from bravado.config import bravado_config_from_config_dict -from bravado_core.spec import Spec +from bravado_core.spec import Spec, build_api_serving_url, _identity +from bravado_core.model import model_discovery +from bravado_core.resource import build_resources from bravado.exception import HTTPNotFound from bravado_core.validate import validate_object from json2html import Json2Html @@ -87,7 +89,6 @@ j2h = Json2Html() pd.options.plotting.backend = "plotly" -pd.set_option('mode.use_inf_as_na', True) pio.templates.default = "simple_white" warnings.formatwarning = lambda msg, *args, **kwargs: f"{msg}\n" warnings.filterwarnings("default", category=DeprecationWarning, module=__name__) @@ -203,6 +204,16 @@ def validate_url(url_string, qualifying=("scheme", "netloc")): url_format = SwaggerFormat( format="url", to_wire=str, to_python=str, validate=validate_url, description="URL", ) +bravado_config_dict = { + "validate_responses": False, + "use_models": False, + "include_missing_properties": False, + "formats": [email_format, url_format], +} +bravado_config = bravado_config_from_config_dict(bravado_config_dict) +for key in set(bravado_config._fields).intersection(set(bravado_config_dict)): + del bravado_config_dict[key] +bravado_config_dict["bravado"] = bravado_config # https://stackoverflow.com/a/8991553 @@ -372,6 +383,7 @@ def from_dict(cls, dct: dict): def _clean(self): """clean the dataframe""" + self.replace([np.inf, -np.inf], np.nan, inplace=True) self.fillna('', inplace=True) self.index = self.index.astype(str) for col in self.columns: @@ -641,9 +653,39 @@ def _run_futures(futures, total: int = 0, timeout: int = -1, desc=None, disable= @functools.lru_cache(maxsize=1000) def _load(protocol, host, headers_json, project, version): + spec_dict = _raw_specs(protocol, host, version) headers = ujson.loads(headers_json) + + if not spec_dict["paths"]: + url = f"{protocol}://{host}" + origin_url = f"{url}/apispec.json" + http_client = RequestsClient() + http_client.session.headers.update(headers) + swagger_spec = Spec.from_dict(spec_dict, origin_url, http_client, bravado_config_dict) + http_client.session.close() + return swagger_spec + + # retrieve list of projects accessible to user + query = {"name": project} if project else {} + query["_fields"] = ["name"] + url = f"{protocol}://{host}" + resp = requests.get(f"{url}/projects/", params=query, headers=headers).json() + + if not resp or not resp["data"]: + raise MPContribsClientError(f"Failed to load projects for query {query}!") + + if project and not resp["data"]: + raise MPContribsClientError(f"{project} doesn't exist, or access denied!") + + projects = sorted(d["name"] for d in resp["data"]) + projects_json = ujson.dumps(projects) + # expand regex-based query parameters for `data` columns + return _expand_params(protocol, host, version, projects_json) + + +@functools.lru_cache(maxsize=1) +def _raw_specs(protocol, host, version): http_client = RequestsClient() - http_client.session.headers.update(headers) url = f"{protocol}://{host}" origin_url = f"{url}/apispec.json" url4fn = origin_url.replace("apispec", f"apispec-{version}").encode('utf-8') @@ -668,35 +710,20 @@ def _load(protocol, host, headers_json, project, version): spec_dict["host"] = host spec_dict["schemes"] = [protocol] - - config = { - "validate_responses": False, - "use_models": False, - "include_missing_properties": False, - "formats": [email_format, url_format], - } - bravado_config = bravado_config_from_config_dict(config) - for key in set(bravado_config._fields).intersection(set(config)): - del config[key] - config["bravado"] = bravado_config - swagger_spec = Spec.from_dict(spec_dict, origin_url, http_client, config) - - if not spec_dict["paths"]: - return swagger_spec - - # expand regex-based query parameters for `data` columns - query = {"name": project} if project else {} - query["_fields"] = ["columns"] - resp = http_client.session.get(f"{url}/projects/", params=query).json() http_client.session.close() + return spec_dict - if not resp or not resp["data"]: - raise MPContribsClientError(f"Failed to load projects for query {query}!") - - if project and not resp["data"]: - raise MPContribsClientError(f"{project} doesn't exist, or access denied!") +@functools.lru_cache(maxsize=100) +def _expand_params(protocol, host, version, projects_json): columns = {"string": [], "number": []} + projects = ujson.loads(projects_json) + query = {"project__in": projects} + query["_fields"] = ["columns"] + url = f"{protocol}://{host}" + http_client = RequestsClient() + http_client.session.headers["Content-Type"] = "application/json" + resp = http_client.session.get(f"{url}/projects/", params=query).json() for proj in resp["data"]: for column in proj["columns"]: @@ -708,29 +735,50 @@ def _load(protocol, host, headers_json, project, version): col = f"{col}__value" columns["number"].append(col) - resource = swagger_spec.resources["contributions"] + spec_dict = _raw_specs(protocol, host, version) + resource = spec_dict["paths"]["/contributions/"]["get"] + raw_params = resource.pop("parameters") + params = {} - for operation_id, operation in resource.operations.items(): - for pn in list(operation.params.keys()): - if pn.startswith("data_"): - param = operation.params.pop(pn) - op = param.name.rsplit('$__', 1)[-1] - typ = param.param_spec.get("type") - key = "number" if typ == "number" else "string" + for param in raw_params: + if param["name"].startswith("^data__"): + op = param["name"].rsplit('$__', 1)[-1] + typ = param["type"] + key = "number" if typ == "number" else "string" - for column in columns[key]: - param_name = f"{column}__{op}" + for column in columns[key]: + param_name = f"{column}__{op}" + if param_name not in params: param_spec = { - k: v - for k, v in param.param_spec.items() - if k != "description" + k: v for k, v in param.items() + if k not in ["name", "description"] } param_spec["name"] = param_name - operation.params[param_name] = Param( - swagger_spec, operation, param_spec - ) + params[param_name] = param_spec + else: + params[param["name"]] = param + + resource["parameters"] = list(params.values()) - return swagger_spec + origin_url = f"{url}/apispec.json" + spec = Spec(spec_dict, origin_url, http_client, bravado_config_dict) + model_discovery(spec) + + if spec.config['internally_dereference_refs']: + spec.deref = _identity + spec._internal_spec_dict = spec.deref_flattened_spec + + for user_defined_format in spec.config['formats']: + spec.register_format(user_defined_format) + + spec.resources = build_resources(spec) + spec.api_url = build_api_serving_url( + spec_dict=spec.spec_dict, + origin_url=spec.origin_url, + use_spec_url_for_base_path=spec.config['use_spec_url_for_base_path'], + ) + http_client.session.close() + return spec @functools.lru_cache(maxsize=1)