Skip to content

Commit

Permalink
Merge pull request #713 from DanielYang59/lazy-import-torch
Browse files Browse the repository at this point in the history
Lazily import `torch`/`pydantic` in the `json` module, speeding up `from monty.json import ...` by 10x
  • Loading branch information
shyuep authored Oct 21, 2024
2 parents 189b6e6 + f33fd5a commit 1ff96e8
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: "3.x"

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
max-parallel: 20
matrix:
os: [ubuntu-latest, macos-14, windows-latest]
python-version: ["3.9", "3.x"]
python-version: ["3.9", "3.12"]

runs-on: ${{ matrix.os }}

Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ repos:
rev: v2.3.0
hooks:
- id: codespell
stages: [commit, commit-msg]
stages: [pre-commit, commit-msg]
exclude_types: [html]
additional_dependencies: [tomli] # needed to read pyproject.toml below py3.11

Expand Down
93 changes: 52 additions & 41 deletions src/monty/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,14 @@
from importlib import import_module
from inspect import getfullargspec
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING
from uuid import UUID, uuid4

import numpy as np
from ruamel.yaml import YAML

try:
import pydantic
except ImportError:
pydantic = None

try:
from pydantic_core import core_schema
except ImportError:
core_schema = None
if TYPE_CHECKING:
from typing import Any

try:
import bson
Expand All @@ -44,15 +37,11 @@
except ImportError:
orjson = None

try:
import torch
except ImportError:
torch = None

__version__ = "3.0.0"


def _load_redirect(redirect_file):
def _load_redirect(redirect_file) -> dict:
try:
with open(redirect_file) as f:
yaml = YAML()
Expand All @@ -63,7 +52,7 @@ def _load_redirect(redirect_file):
return {}

# Convert the full paths to module/class
redirect_dict = defaultdict(dict)
redirect_dict: dict = defaultdict(dict)
for old_path, new_path in d.items():
old_class = old_path.split(".")[-1]
old_module = ".".join(old_path.split(".")[:-1])
Expand All @@ -79,7 +68,7 @@ def _load_redirect(redirect_file):
return dict(redirect_dict)


def _check_type(obj, type_str) -> bool:
def _check_type(obj, type_str: tuple[str, ...] | str) -> bool:
"""Alternative to isinstance that avoids imports.
Checks whether obj is an instance of the type defined by type_str. This
Expand Down Expand Up @@ -113,7 +102,7 @@ class B(A):
mro = type(obj).mro()
except TypeError:
return False
return any(o.__module__ + "." + o.__name__ == ts for o in mro for ts in type_str)
return any(f"{o.__module__}.{o.__name__}" == ts for o in mro for ts in type_str)


class MSONable:
Expand Down Expand Up @@ -330,8 +319,11 @@ def __get_pydantic_core_schema__(cls, source_type, handler):
"""
pydantic v2 core schema definition
"""
if core_schema is None:
raise RuntimeError("Pydantic >= 2.0 is required for validation")
try:
from pydantic_core import core_schema

except ImportError as exc:
raise RuntimeError("Pydantic >= 2.0 is required for validation") from exc

s = core_schema.with_info_plain_validator_function(cls.validate_monty_v2)

Expand Down Expand Up @@ -533,7 +525,7 @@ def _recursive_name_object_map_replacement(d, name_object_map):
class MontyEncoder(json.JSONEncoder):
"""
A Json Encoder which supports the MSONable API, plus adds support for
numpy arrays, datetime objects, bson ObjectIds (requires bson).
NumPy arrays, datetime objects, bson ObjectIds (requires bson).
Usage::
# Add it as a *cls* keyword when using json.dump
json.dumps(object, cls=MontyEncoder)
Expand Down Expand Up @@ -578,8 +570,8 @@ def default(self, o) -> dict:
if isinstance(o, Path):
return {"@module": "pathlib", "@class": "Path", "string": str(o)}

if torch is not None and isinstance(o, torch.Tensor):
# Support for Pytorch Tensors.
# Support for Pytorch Tensors
if _check_type(o, "torch.Tensor"):
d: dict[str, Any] = {
"@module": "torch",
"@class": "Tensor",
Expand All @@ -605,6 +597,7 @@ def default(self, o) -> dict:
"dtype": str(o.dtype),
"data": o.tolist(),
}

if isinstance(o, np.generic):
return o.item()

Expand Down Expand Up @@ -651,7 +644,7 @@ def default(self, o) -> dict:
raise AttributeError(e)

try:
if pydantic is not None and isinstance(o, pydantic.BaseModel):
if _check_type(o, "pydantic.main.BaseModel"):
d = o.model_dump()
elif (
dataclasses is not None
Expand Down Expand Up @@ -781,11 +774,18 @@ def process_decoded(self, d):
return cls_.from_dict(data)
if issubclass(cls_, Enum):
return cls_(d["value"])
if pydantic is not None and issubclass(
cls_, pydantic.BaseModel
): # pylint: disable=E1101
d = {k: self.process_decoded(v) for k, v in data.items()}
return cls_(**d)

try:
import pydantic

if issubclass(cls_, pydantic.BaseModel):
d = {
k: self.process_decoded(v) for k, v in data.items()
}
return cls_(**d)
except ImportError:
pass

if (
dataclasses is not None
and (not issubclass(cls_, MSONable))
Expand All @@ -794,15 +794,21 @@ def process_decoded(self, d):
d = {k: self.process_decoded(v) for k, v in data.items()}
return cls_(**d)

elif torch is not None and modname == "torch" and classname == "Tensor":
if "Complex" in d["dtype"]:
return torch.tensor( # pylint: disable=E1101
[
np.array(r) + np.array(i) * 1j
for r, i in zip(*d["data"])
],
).type(d["dtype"])
return torch.tensor(d["data"]).type(d["dtype"]) # pylint: disable=E1101
elif modname == "torch" and classname == "Tensor":
try:
import torch # import torch is very expensive

if "Complex" in d["dtype"]:
return torch.tensor(
[
np.array(r) + np.array(i) * 1j
for r, i in zip(*d["data"])
],
).type(d["dtype"])
return torch.tensor(d["data"]).type(d["dtype"])

except ImportError:
pass

elif modname == "numpy" and classname == "array":
if d["dtype"].startswith("complex"):
Expand Down Expand Up @@ -858,8 +864,8 @@ def decode(self, s):
"""
if orjson is not None:
try:
d = orjson.loads(s) # pylint: disable=E1101
except orjson.JSONDecodeError: # pylint: disable=E1101
d = orjson.loads(s)
except orjson.JSONDecodeError:
d = json.loads(s)
else:
d = json.loads(s)
Expand Down Expand Up @@ -916,6 +922,7 @@ def jsanitize(
or (bson is not None and isinstance(obj, bson.objectid.ObjectId))
):
return obj

if isinstance(obj, (list, tuple)):
return [
jsanitize(
Expand Down Expand Up @@ -955,6 +962,7 @@ def jsanitize(
),
):
return obj.to_dict()

if isinstance(obj, dict):
return {
str(k): jsanitize(
Expand All @@ -966,10 +974,13 @@ def jsanitize(
)
for k, v in obj.items()
}

if isinstance(obj, (int, float)):
return obj

if obj is None:
return None

if isinstance(obj, (pathlib.Path, datetime.datetime)):
return str(obj)

Expand All @@ -991,7 +1002,7 @@ def jsanitize(
if isinstance(obj, str):
return obj

if pydantic is not None and isinstance(obj, pydantic.BaseModel): # pylint: disable=E1101
if _check_type(obj, "pydantic.main.BaseModel"):
return jsanitize(
MontyEncoder().default(obj),
strict=strict,
Expand Down

0 comments on commit 1ff96e8

Please sign in to comment.