Skip to content

Commit

Permalink
also lazily import numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Oct 20, 2024
1 parent 196dcf9 commit 952eb94
Showing 1 changed file with 56 additions and 32 deletions.
88 changes: 56 additions & 32 deletions src/monty/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@
from typing import Any
from uuid import UUID, uuid4

try:
import numpy as np
except ImportError:
np = None

try:
import pydantic
except ImportError:
Expand Down Expand Up @@ -594,7 +589,9 @@ def default(self, o) -> dict:
d["data"] = o.numpy().tolist()
return d

if np is not None:
try:
import numpy as np

if isinstance(o, np.ndarray):
if str(o.dtype).startswith("complex"):
return {
Expand All @@ -611,6 +608,8 @@ def default(self, o) -> dict:
}
if isinstance(o, np.generic):
return o.item()
except ImportError:
pass

if _check_type(o, "pandas.core.frame.DataFrame"):
return {
Expand Down Expand Up @@ -800,6 +799,7 @@ def process_decoded(self, d):

elif modname == "torch" and classname == "Tensor":
try:
import numpy as np
import torch

if "Complex" in d["dtype"]:
Expand All @@ -814,16 +814,22 @@ def process_decoded(self, d):
except ImportError:
pass

elif np is not None and modname == "numpy" and classname == "array":
if d["dtype"].startswith("complex"):
return np.array(
[
np.array(r) + np.array(i) * 1j
for r, i in zip(*d["data"])
],
dtype=d["dtype"],
)
return np.array(d["data"], dtype=d["dtype"])
elif modname == "numpy" and classname == "array":
try:
import numpy as np

if d["dtype"].startswith("complex"):
return np.array(
[
np.array(r) + np.array(i) * 1j
for r, i in zip(*d["data"])
],
dtype=d["dtype"],
)
return np.array(d["data"], dtype=d["dtype"])

except ImportError:
pass

elif modname == "pandas":
import pandas as pd
Expand Down Expand Up @@ -926,6 +932,7 @@ def jsanitize(
or (bson is not None and isinstance(obj, bson.objectid.ObjectId))
):
return obj

if isinstance(obj, (list, tuple)):
return [
jsanitize(
Expand All @@ -937,22 +944,35 @@ def jsanitize(
)
for i in obj
]
if np is not None and isinstance(obj, np.ndarray):
try:
return [
jsanitize(
i,
strict=strict,
allow_bson=allow_bson,
enum_values=enum_values,
recursive_msonable=recursive_msonable,
)
for i in obj.tolist()
]
except TypeError:
return obj.tolist()
if np is not None and isinstance(obj, np.generic):
return obj.item()

try:
import numpy as np

if isinstance(obj, np.ndarray):
try:
return [
jsanitize(
i,
strict=strict,
allow_bson=allow_bson,
enum_values=enum_values,
recursive_msonable=recursive_msonable,
)
for i in obj.tolist()
]
except TypeError:
return obj.tolist()
except ImportError:
pass

try:
import numpy as np

if isinstance(obj, np.generic):
return obj.item()
except ImportError:
pass

if _check_type(
obj,
(
Expand All @@ -962,6 +982,7 @@ def jsanitize(
),
):
return obj.to_dict()

if isinstance(obj, dict):
return {
str(k): jsanitize(
Expand All @@ -973,10 +994,13 @@ def jsanitize(
)
for k, v in obj.items()
}

if isinstance(obj, (int, float)):
return obj

if obj is None:
return None

if isinstance(obj, (pathlib.Path, datetime.datetime)):
return str(obj)

Expand Down

0 comments on commit 952eb94

Please sign in to comment.