From 196dcf9eabcc1c870c5c93b82a5eab0d039b35b7 Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel)" Date: Sun, 20 Oct 2024 11:02:11 +0800 Subject: [PATCH] lazily import torch for MontyDecoder --- src/monty/json.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/src/monty/json.py b/src/monty/json.py index e05aab0b..25ae42a5 100644 --- a/src/monty/json.py +++ b/src/monty/json.py @@ -52,11 +52,6 @@ orjson = None -try: - import torch -except ImportError: - torch = None - __version__ = "3.0.0" @@ -803,15 +798,21 @@ def process_decoded(self, d): d = {k: self.process_decoded(v) for k, v in data.items()} return cls_(**d) - elif torch is not None and modname == "torch" and classname == "Tensor": - if "Complex" in d["dtype"]: - return torch.tensor( # pylint: disable=E1101 - [ - np.array(r) + np.array(i) * 1j - for r, i in zip(*d["data"]) - ], - ).type(d["dtype"]) - return torch.tensor(d["data"]).type(d["dtype"]) # pylint: disable=E1101 + elif modname == "torch" and classname == "Tensor": + try: + import torch + + if "Complex" in d["dtype"]: + return torch.tensor( # pylint: disable=E1101 + [ + np.array(r) + np.array(i) * 1j + for r, i in zip(*d["data"]) + ], + ).type(d["dtype"]) + return torch.tensor(d["data"]).type(d["dtype"]) # pylint: disable=E1101 + + except ImportError: + pass elif np is not None and modname == "numpy" and classname == "array": if d["dtype"].startswith("complex"):