From 952eb945900aa96da39ffeea13279e56fa70ed16 Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel)" Date: Sun, 20 Oct 2024 11:26:49 +0800 Subject: [PATCH] also lazily import numpy --- src/monty/json.py | 88 ++++++++++++++++++++++++++++++----------------- 1 file changed, 56 insertions(+), 32 deletions(-) diff --git a/src/monty/json.py b/src/monty/json.py index 25ae42a5..e3402c45 100644 --- a/src/monty/json.py +++ b/src/monty/json.py @@ -21,11 +21,6 @@ from typing import Any from uuid import UUID, uuid4 -try: - import numpy as np -except ImportError: - np = None - try: import pydantic except ImportError: @@ -594,7 +589,9 @@ def default(self, o) -> dict: d["data"] = o.numpy().tolist() return d - if np is not None: + try: + import numpy as np + if isinstance(o, np.ndarray): if str(o.dtype).startswith("complex"): return { @@ -611,6 +608,8 @@ def default(self, o) -> dict: } if isinstance(o, np.generic): return o.item() + except ImportError: + pass if _check_type(o, "pandas.core.frame.DataFrame"): return { @@ -800,6 +799,7 @@ def process_decoded(self, d): elif modname == "torch" and classname == "Tensor": try: + import numpy as np import torch if "Complex" in d["dtype"]: @@ -814,16 +814,22 @@ def process_decoded(self, d): except ImportError: pass - elif np is not None and modname == "numpy" and classname == "array": - if d["dtype"].startswith("complex"): - return np.array( - [ - np.array(r) + np.array(i) * 1j - for r, i in zip(*d["data"]) - ], - dtype=d["dtype"], - ) - return np.array(d["data"], dtype=d["dtype"]) + elif modname == "numpy" and classname == "array": + try: + import numpy as np + + if d["dtype"].startswith("complex"): + return np.array( + [ + np.array(r) + np.array(i) * 1j + for r, i in zip(*d["data"]) + ], + dtype=d["dtype"], + ) + return np.array(d["data"], dtype=d["dtype"]) + + except ImportError: + pass elif modname == "pandas": import pandas as pd @@ -926,6 +932,7 @@ def jsanitize( or (bson is not None and isinstance(obj, bson.objectid.ObjectId)) ): return obj + if isinstance(obj, (list, tuple)): return [ jsanitize( @@ -937,22 +944,35 @@ def jsanitize( ) for i in obj ] - if np is not None and isinstance(obj, np.ndarray): - try: - return [ - jsanitize( - i, - strict=strict, - allow_bson=allow_bson, - enum_values=enum_values, - recursive_msonable=recursive_msonable, - ) - for i in obj.tolist() - ] - except TypeError: - return obj.tolist() - if np is not None and isinstance(obj, np.generic): - return obj.item() + + try: + import numpy as np + + if isinstance(obj, np.ndarray): + try: + return [ + jsanitize( + i, + strict=strict, + allow_bson=allow_bson, + enum_values=enum_values, + recursive_msonable=recursive_msonable, + ) + for i in obj.tolist() + ] + except TypeError: + return obj.tolist() + except ImportError: + pass + + try: + import numpy as np + + if isinstance(obj, np.generic): + return obj.item() + except ImportError: + pass + if _check_type( obj, ( @@ -962,6 +982,7 @@ def jsanitize( ), ): return obj.to_dict() + if isinstance(obj, dict): return { str(k): jsanitize( @@ -973,10 +994,13 @@ def jsanitize( ) for k, v in obj.items() } + if isinstance(obj, (int, float)): return obj + if obj is None: return None + if isinstance(obj, (pathlib.Path, datetime.datetime)): return str(obj)