Commit
Merge pull request #1625 from materialsproject/dev
cleanup, efficiency improvs, Attachments
tschaume authored Sep 13, 2023
2 parents fedafe7 + 6ede7ae commit cb16d34
Showing 1 changed file with 165 additions and 88 deletions.
253 changes: 165 additions & 88 deletions mpcontribs-client/mpcontribs/client/__init__.py
@@ -215,6 +215,12 @@ def grouper(n, iterable):
yield chunk


def _compress(data):
data_json = ujson.dumps(data, indent=4).encode("utf-8")
content = gzip.compress(data_json)
return len(content), content

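# Usage sketch for _compress (illustrative; assumes gzip/ujson are imported at module
# level and MAX_BYTES is the module-level size limit used further below):
#
#     size, blob = _compress({"a": [1, 2, 3]})
#     if size > MAX_BYTES:
#         raise MPContribsClientError(f"data too large ({size} > {MAX_BYTES})!")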

def get_session(session=None):
adapter_kwargs = dict(max_retries=Retry(
total=RETRIES,
@@ -259,6 +265,24 @@ def _response_hook(resp, *args, **kwargs):
resp.count = 0


def _chunk_by_size(items, max_size=0.95*MAX_BYTES):
buffer, buffer_size = [], 0

for idx, item in enumerate(items):
item_size = _compress(item)[0]

if buffer_size + item_size <= max_size:
buffer.append(item)
buffer_size += item_size
else:
yield buffer
buffer = [item]
buffer_size = item_size

if buffer_size > 0:
yield buffer

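# Usage sketch for _chunk_by_size (illustrative; the record contents are hypothetical):
# split a large list of JSON-serializable records into chunks whose summed per-item
# gzip size stays below ~95% of MAX_BYTES:
#
#     rows = [{"x": i, "y": i * i} for i in range(100000)]
#     for chunk in _chunk_by_size(rows):
#         size, blob = _compress(chunk)  # each chunk targets <= 0.95 * MAX_BYTES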

def visit(path, key, value):
if isinstance(value, dict) and "display" in value:
return key, value["display"]
@@ -346,6 +370,34 @@ def from_dict(cls, dct: dict):
ret.attrs = {k: v for k, v in dct["attrs"].items()}
return ret

def _clean(self):
"""clean the dataframe"""
self.fillna('', inplace=True)
self.index = self.index.astype(str)
for col in self.columns:
self[col] = self[col].astype(str)

def _attrs_as_dict(self):
name = self.attrs.get("name", "table")
title = self.attrs.get("title", name)
labels = self.attrs.get("labels", {})
index = self.index.name
variable = self.columns.name

if index and "index" not in labels:
labels["index"] = index
if variable and "variable" not in labels:
labels["variable"] = variable

return name, {"title": title, "labels": labels}

def as_dict(self):
"""Convert Table to plain dictionary"""
self._clean()
dct = self.to_dict(orient="split")
dct["name"], dct["attrs"] = self._attrs_as_dict()
return dct

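# Usage sketch for Table.as_dict (illustrative; the column labels and attrs are
# hypothetical, pandas is assumed importable):
#
#     import pandas as pd
#
#     df = pd.DataFrame({"T [K]": [300, 400], "S [uV/K]": [1.2, 3.4]})
#     table = Table(df)
#     table.attrs = {"name": "seebeck", "title": "Seebeck vs temperature"}
#     dct = table.as_dict()  # adds "name" and "attrs" to the orient="split" dict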

class Structure(PmgStructure):
"""Wrapper class around pymatgen.Structure to provide display() and info()"""
@@ -427,17 +479,15 @@ def name(self) -> str:
return self["name"]

@classmethod
def from_data(cls, name: str, data: Union[list, dict]):
def from_data(cls, data: Union[list, dict], name: str = "attachment"):
"""Construct attachment from data dict or list
Args:
name (str): name for the attachment
data (list,dict): JSON-serializable data to go into the attachment
name (str): name for the attachment
"""
filename = name + ".json.gz"
data_json = ujson.dumps(data, indent=4).encode("utf-8")
content = gzip.compress(data_json)
size = len(content)
size, content = _compress(data)

if size > MAX_BYTES:
raise MPContribsClientError(f"{name} too large ({size} > {MAX_BYTES})!")
@@ -449,8 +499,8 @@ def from_data(cls, name: str, data: Union[list, dict]):
)

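    # Usage sketch for from_data (illustrative; the payload is hypothetical): the gzipped
    # JSON must fit within MAX_BYTES, otherwise MPContribsClientError is raised as above:
    #
    #     attm = Attachment.from_data({"peaks": [1.2, 3.4]}, name="spectrum")
    #     attm["name"]  # "spectrum.json.gz"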
@classmethod
def from_textfile(cls, path: Union[Path, str]):
"""Construct attachment from (uncompressed) text file
def from_file(cls, path: Union[Path, str]):
"""Construct attachment from file
Args:
path (pathlib.Path, str): file path
@@ -461,12 +511,17 @@ def from_textfile(cls, path: Union[Path, str]):
typ = type(path)
raise MPContribsClientError(f"use pathlib.Path or str (is: {typ}).")

kind = guess(str(path))
supported = isinstance(kind, SUPPORTED_FILETYPES)
content = path.read_bytes()

try:
content = gzip.compress(content)
except Exception:
raise MPContribsClientError(f"Failed to gzip {path}.")
if not supported: # try to gzip text file
try:
content = gzip.compress(content)
except Exception:
raise MPContribsClientError(
f"{path} is not text file or {SUPPORTED_MIMES}."
)

size = len(content)

@@ -475,7 +530,7 @@ def from_textfile(cls, path: Union[Path, str]):

return cls(
name=path.name,
mime="application/gzip",
mime=kind.mime if supported else "application/gzip",
content=b64encode(content).decode("utf-8")
)

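    # Usage sketch for from_file (illustrative; file names are hypothetical): supported
    # binary types keep their detected MIME type, anything else is treated as text and
    # gzipped:
    #
    #     from pathlib import Path
    #
    #     img = Attachment.from_file(Path("dos.png"))  # keeps image MIME, if PNG is supported
    #     txt = Attachment.from_file("README.txt")     # gzipped text, mime "application/gzip"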
@@ -490,6 +545,63 @@ def from_dict(cls, dct: dict):
return cls((k, v) for k, v in dct.items() if k in keys)


class Attachments(list):
"""Wrapper class to handle attachments automatically"""
# TODO implement "plural" versions for Attachment methods

@classmethod
def from_list(cls, elements: list):
if not isinstance(elements, list):
raise MPContribsClientError("use list to init Attachments")

attachments = []

for element in elements:
if len(attachments) >= MAX_ELEMS:
raise MPContribsClientError(f"max {MAX_ELEMS} attachments reached")

if isinstance(element, Attachment):
# simply append, size check already performed
attachments.append(element)
elif isinstance(element, (list, dict)):
attachments += cls.from_data(element)
elif isinstance(element, (str, Path)):
# don't split files, user should use from_data to split
attm = Attachment.from_file(element)
attachments.append(attm)
else:
raise MPContribsClientError("invalid element for Attachments")

return attachments

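    # Usage sketch for from_list (illustrative; elements are hypothetical): mix ready-made
    # Attachments, raw data, and file paths; each element is converted and the MAX_ELEMS
    # cap is enforced:
    #
    #     attms = Attachments.from_list([
    #         Attachment.from_data({"meta": 1}, name="meta"),
    #         {"energies": [0.1, 0.2]},   # becomes one or more attachments via from_data
    #         "LICENSE.txt",              # read from disk via Attachment.from_file
    #     ])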
@classmethod
def from_data(cls, data: Union[list, dict], prefix: str = "attachment"):
"""Construct list of attachments from data dict or list
Args:
data (list,dict): JSON-serializable data to go into the attachments
prefix (str): prefix for attachment name(s)
"""
try:
# try to make single attachment first
return [Attachment.from_data(data, name=prefix)]
except MPContribsClientError:
# chunk data into multiple attachments with < MAX_BYTES
if isinstance(data, dict):
raise NotImplementedError("dicts not supported yet")

attachments = []

for idx, chunk in enumerate(_chunk_by_size(data)):
if len(attachments) > MAX_ELEMS:
raise MPContribsClientError("list too large to split")

attm = Attachment.from_data(chunk, name=f"{prefix}{idx}")
attachments.append(attm)

return attachments

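    # Usage sketch for from_data (illustrative; the list contents are hypothetical): a list
    # too large for a single attachment is split into gzipped chunks named
    # "<prefix>0", "<prefix>1", ...:
    #
    #     big = [{"step": i, "energy": -i * 0.01} for i in range(10000000)]
    #     attms = Attachments.from_data(big, prefix="relaxation")
    #     [a["name"] for a in attms]  # ["relaxation0.json.gz", "relaxation1.json.gz", ...]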

classes_map = {"structures": Structure, "tables": Table, "attachments": Attachment}


@@ -1880,15 +1992,17 @@ def submit_contributions(
continue

is_structure = isinstance(element, PmgStructure)
is_table = isinstance(element, pd.DataFrame)
is_attachment = isinstance(element, Path) or isinstance(element, Attachment)
is_table = isinstance(element, (pd.DataFrame, Table))
is_attachment = isinstance(element, (str, Path, Attachment))
if component == "structures" and not is_structure:
raise MPContribsClientError(f"Use pymatgen Structure for {component}!")
elif component == "tables" and not is_table:
raise MPContribsClientError(f"Use pandas DataFrame for {component}!")
raise MPContribsClientError(
f"Use pandas DataFrame or mpontribs.client.Table for {component}!"
)
elif component == "attachments" and not is_attachment:
raise MPContribsClientError(
f"Use pathlib.Path or mpcontribs.client.Attachment for {component}!"
f"Use str, pathlib.Path or mpcontribs.client.Attachment for {component}"
)

if is_structure:
@@ -1904,57 +2018,21 @@ def submit_contributions(
logger.warning("storing structure properties not supported, yet!")
del dct["properties"]
elif is_table:
element.fillna('', inplace=True)
element.index = element.index.astype(str)
for col in element.columns:
element[col] = element[col].astype(str)
dct = element.to_dict(orient="split")
table = element if isinstance(element, Table) else Table(element)
table._clean()
dct = table.to_dict(orient="split")
elif is_attachment:
if isinstance(element, Path):
kind = guess(str(element))

if not isinstance(kind, SUPPORTED_FILETYPES):
raise MPContribsClientError(
f"{element.name} not supported. Use one of {SUPPORTED_MIMES}!"
)

content = element.read_bytes()
size = len(content)
if isinstance(element, (str, Path)):
element = Attachment.from_file(element)

if size > MAX_BYTES:
raise MPContribsClientError(
f"{element.name} too large ({size} > {MAX_BYTES})!"
)

dct = {
"mime": kind.mime,
"content": b64encode(content).decode("utf-8")
}
else:
dct = {k: element[k] for k in ["mime", "content"]}
dct = {k: element[k] for k in ["mime", "content"]}

digest = get_md5(dct)

if is_structure:
dct["name"] = getattr(element, "name", None)
if not dct["name"]:
c = element.composition
comp = c.get_integer_formula_and_factor()
dct["name"] = f"{comp[0]}-{idx}" if nelems > 1 else comp[0]
dct["name"] = getattr(element, "name", "structure")
elif is_table:
name = f"table-{idx}" if nelems > 1 else "table"
dct["name"] = element.attrs.get("name", name)
title = element.attrs.get("title", dct["name"])
labels = element.attrs.get("labels", {})
index = element.index.name
variable = element.columns.name

if index and "index" not in labels:
labels["index"] = index
if variable and "variable" not in labels:
labels["variable"] = variable

dct["attrs"] = {"title": title, "labels": labels}
dct["name"], dct["attrs"] = table._attrs_as_dict()
elif is_attachment:
dct["name"] = element.name

@@ -1977,7 +2055,6 @@ def submit_contributions(

# submit contributions
if contribs:
per_page = self._get_per_page()
total, total_processed = 0, 0

def post_future(track_id, payload):
@@ -2006,35 +2083,36 @@ def put_future(pk, payload):
retries = 0

while contribs[project_name]:
futures = []
for idx, chunk in enumerate(grouper(per_page, contribs[project_name])):
post_chunk = []
for c in chunk:
if "id" in c:
pk = c.pop("id")
if not c:
logger.error(
f"SKIPPED update of {project_name}/{pk}: empty."
)

payload = ujson.dumps(c).encode("utf-8")
if len(payload) < MAX_PAYLOAD:
futures.append(put_future(pk, payload))
else:
logger.error(
f"SKIPPED update of {project_name}/{pk}: too large."
)
else:
post_chunk.append(c)
futures, post_chunk, idx = [], [], 0

if post_chunk:
payload = ujson.dumps(post_chunk).encode("utf-8")
for n, c in enumerate(contribs[project_name]):
if "id" in c:
pk = c.pop("id")
if not c:
logger.error(f"SKIPPED: update of {project_name}/{pk} empty.")

payload = ujson.dumps(c).encode("utf-8")
if len(payload) < MAX_PAYLOAD:
futures.append(post_future(idx, payload))
futures.append(put_future(pk, payload))
else:
logger.error(
f"SKIPPED {project_name}/{idx}: too large, reduce per_request"
)
logger.error(f"SKIPPED: update of {project_name}/{pk} too large.")
else:
next_payload = ujson.dumps(post_chunk + [c]).encode("utf-8")
if len(next_payload) >= MAX_PAYLOAD:
if post_chunk:
payload = ujson.dumps(post_chunk).encode("utf-8")
futures.append(post_future(idx, payload))
post_chunk.clear()
idx += 1
else:
logger.error(f"SKIPPED: contrib {project_name}/{n} too large.")
continue

post_chunk.append(c)

if post_chunk and len(futures) < ncontribs:
payload = ujson.dumps(post_chunk).encode("utf-8")
futures.append(post_future(idx, payload))

if not futures:
break # nothing to do
@@ -2059,7 +2137,6 @@ def put_future(pk, payload):
c for c in contribs[project_name]
if c["identifier"] not in existing_ids
]
per_page = int(per_page / 2)
retries += 1
else:
contribs[project_name] = [] # abort retrying
