Skip to content

Commit

Permalink
Merge pull request #1675 from materialsproject/dev
Browse files Browse the repository at this point in the history
client: cache specs by projects, fix mem leak
  • Loading branch information
tschaume authored Nov 27, 2023
2 parents 6d63e27 + a7c1f8c commit 42b99c0
Showing 1 changed file with 94 additions and 46 deletions.
140 changes: 94 additions & 46 deletions mpcontribs-client/mpcontribs/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import gzip
import warnings
import pandas as pd
import numpy as np
import plotly.io as pio
import itertools
import functools
Expand All @@ -18,7 +19,6 @@
from math import isclose
from semantic_version import Version
from requests.exceptions import RequestException
from bravado_core.param import Param
from bson.objectid import ObjectId
from typing import Union, Type, List
from tqdm.auto import tqdm
Expand All @@ -39,7 +39,9 @@
from bravado.requests_client import RequestsClient
from bravado.swagger_model import Loader
from bravado.config import bravado_config_from_config_dict
from bravado_core.spec import Spec
from bravado_core.spec import Spec, build_api_serving_url, _identity
from bravado_core.model import model_discovery
from bravado_core.resource import build_resources
from bravado.exception import HTTPNotFound
from bravado_core.validate import validate_object
from json2html import Json2Html
Expand Down Expand Up @@ -87,7 +89,6 @@

j2h = Json2Html()
pd.options.plotting.backend = "plotly"
pd.set_option('mode.use_inf_as_na', True)
pio.templates.default = "simple_white"
warnings.formatwarning = lambda msg, *args, **kwargs: f"{msg}\n"
warnings.filterwarnings("default", category=DeprecationWarning, module=__name__)
Expand Down Expand Up @@ -203,6 +204,16 @@ def validate_url(url_string, qualifying=("scheme", "netloc")):
url_format = SwaggerFormat(
format="url", to_wire=str, to_python=str, validate=validate_url, description="URL",
)
bravado_config_dict = {
"validate_responses": False,
"use_models": False,
"include_missing_properties": False,
"formats": [email_format, url_format],
}
bravado_config = bravado_config_from_config_dict(bravado_config_dict)
for key in set(bravado_config._fields).intersection(set(bravado_config_dict)):
del bravado_config_dict[key]
bravado_config_dict["bravado"] = bravado_config


# https://stackoverflow.com/a/8991553
Expand Down Expand Up @@ -372,6 +383,7 @@ def from_dict(cls, dct: dict):

def _clean(self):
"""clean the dataframe"""
self.replace([np.inf, -np.inf], np.nan, inplace=True)
self.fillna('', inplace=True)
self.index = self.index.astype(str)
for col in self.columns:
Expand Down Expand Up @@ -641,9 +653,39 @@ def _run_futures(futures, total: int = 0, timeout: int = -1, desc=None, disable=

@functools.lru_cache(maxsize=1000)
def _load(protocol, host, headers_json, project, version):
spec_dict = _raw_specs(protocol, host, version)
headers = ujson.loads(headers_json)

if not spec_dict["paths"]:
url = f"{protocol}://{host}"
origin_url = f"{url}/apispec.json"
http_client = RequestsClient()
http_client.session.headers.update(headers)
swagger_spec = Spec.from_dict(spec_dict, origin_url, http_client, bravado_config_dict)
http_client.session.close()
return swagger_spec

# retrieve list of projects accessible to user
query = {"name": project} if project else {}
query["_fields"] = ["name"]
url = f"{protocol}://{host}"
resp = requests.get(f"{url}/projects/", params=query, headers=headers).json()

if not resp or not resp["data"]:
raise MPContribsClientError(f"Failed to load projects for query {query}!")

if project and not resp["data"]:
raise MPContribsClientError(f"{project} doesn't exist, or access denied!")

projects = sorted(d["name"] for d in resp["data"])
projects_json = ujson.dumps(projects)
# expand regex-based query parameters for `data` columns
return _expand_params(protocol, host, version, projects_json)


@functools.lru_cache(maxsize=1)
def _raw_specs(protocol, host, version):
http_client = RequestsClient()
http_client.session.headers.update(headers)
url = f"{protocol}://{host}"
origin_url = f"{url}/apispec.json"
url4fn = origin_url.replace("apispec", f"apispec-{version}").encode('utf-8')
Expand All @@ -668,35 +710,20 @@ def _load(protocol, host, headers_json, project, version):

spec_dict["host"] = host
spec_dict["schemes"] = [protocol]

config = {
"validate_responses": False,
"use_models": False,
"include_missing_properties": False,
"formats": [email_format, url_format],
}
bravado_config = bravado_config_from_config_dict(config)
for key in set(bravado_config._fields).intersection(set(config)):
del config[key]
config["bravado"] = bravado_config
swagger_spec = Spec.from_dict(spec_dict, origin_url, http_client, config)

if not spec_dict["paths"]:
return swagger_spec

# expand regex-based query parameters for `data` columns
query = {"name": project} if project else {}
query["_fields"] = ["columns"]
resp = http_client.session.get(f"{url}/projects/", params=query).json()
http_client.session.close()
return spec_dict

if not resp or not resp["data"]:
raise MPContribsClientError(f"Failed to load projects for query {query}!")

if project and not resp["data"]:
raise MPContribsClientError(f"{project} doesn't exist, or access denied!")

@functools.lru_cache(maxsize=100)
def _expand_params(protocol, host, version, projects_json):
columns = {"string": [], "number": []}
projects = ujson.loads(projects_json)
query = {"project__in": projects}
query["_fields"] = ["columns"]
url = f"{protocol}://{host}"
http_client = RequestsClient()
http_client.session.headers["Content-Type"] = "application/json"
resp = http_client.session.get(f"{url}/projects/", params=query).json()

for proj in resp["data"]:
for column in proj["columns"]:
Expand All @@ -708,29 +735,50 @@ def _load(protocol, host, headers_json, project, version):
col = f"{col}__value"
columns["number"].append(col)

resource = swagger_spec.resources["contributions"]
spec_dict = _raw_specs(protocol, host, version)
resource = spec_dict["paths"]["/contributions/"]["get"]
raw_params = resource.pop("parameters")
params = {}

for operation_id, operation in resource.operations.items():
for pn in list(operation.params.keys()):
if pn.startswith("data_"):
param = operation.params.pop(pn)
op = param.name.rsplit('$__', 1)[-1]
typ = param.param_spec.get("type")
key = "number" if typ == "number" else "string"
for param in raw_params:
if param["name"].startswith("^data__"):
op = param["name"].rsplit('$__', 1)[-1]
typ = param["type"]
key = "number" if typ == "number" else "string"

for column in columns[key]:
param_name = f"{column}__{op}"
for column in columns[key]:
param_name = f"{column}__{op}"
if param_name not in params:
param_spec = {
k: v
for k, v in param.param_spec.items()
if k != "description"
k: v for k, v in param.items()
if k not in ["name", "description"]
}
param_spec["name"] = param_name
operation.params[param_name] = Param(
swagger_spec, operation, param_spec
)
params[param_name] = param_spec
else:
params[param["name"]] = param

resource["parameters"] = list(params.values())

return swagger_spec
origin_url = f"{url}/apispec.json"
spec = Spec(spec_dict, origin_url, http_client, bravado_config_dict)
model_discovery(spec)

if spec.config['internally_dereference_refs']:
spec.deref = _identity
spec._internal_spec_dict = spec.deref_flattened_spec

for user_defined_format in spec.config['formats']:
spec.register_format(user_defined_format)

spec.resources = build_resources(spec)
spec.api_url = build_api_serving_url(
spec_dict=spec.spec_dict,
origin_url=spec.origin_url,
use_spec_url_for_base_path=spec.config['use_spec_url_for_base_path'],
)
http_client.session.close()
return spec


@functools.lru_cache(maxsize=1)
Expand Down

0 comments on commit 42b99c0

Please sign in to comment.