Skip to content

Commit

Permalink
lazily import torch for MontyDecoder
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Oct 20, 2024
1 parent a7752ee commit 196dcf9
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions src/monty/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,6 @@
orjson = None


try:
import torch
except ImportError:
torch = None

__version__ = "3.0.0"


Expand Down Expand Up @@ -803,15 +798,21 @@ def process_decoded(self, d):
d = {k: self.process_decoded(v) for k, v in data.items()}
return cls_(**d)

elif torch is not None and modname == "torch" and classname == "Tensor":
if "Complex" in d["dtype"]:
return torch.tensor( # pylint: disable=E1101
[
np.array(r) + np.array(i) * 1j
for r, i in zip(*d["data"])
],
).type(d["dtype"])
return torch.tensor(d["data"]).type(d["dtype"]) # pylint: disable=E1101
elif modname == "torch" and classname == "Tensor":
try:
import torch

if "Complex" in d["dtype"]:
return torch.tensor( # pylint: disable=E1101
[
np.array(r) + np.array(i) * 1j
for r, i in zip(*d["data"])
],
).type(d["dtype"])
return torch.tensor(d["data"]).type(d["dtype"]) # pylint: disable=E1101

except ImportError:
pass

elif np is not None and modname == "numpy" and classname == "array":
if d["dtype"].startswith("complex"):
Expand Down

0 comments on commit 196dcf9

Please sign in to comment.