Skip to content

Commit

Permalink
NEED CONFIRM: patch numpy random generator
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Sep 11, 2024
1 parent d0c6b5b commit b4f05c5
Showing 1 changed file with 31 additions and 30 deletions.
61 changes: 31 additions & 30 deletions tests/test_crystal_graph.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from __future__ import annotations

from time import perf_counter
from unittest.mock import patch

import numpy as np
from pymatgen.core import Structure

from chgnet import ROOT
from chgnet.graph import CrystalGraphConverter

np.random.seed(0)

structure = Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif")
converter = CrystalGraphConverter(atom_graph_cutoff=5, bond_graph_cutoff=3)
converter_legacy = CrystalGraphConverter(
Expand Down Expand Up @@ -127,55 +126,57 @@ def test_crystal_graph_different_cutoff_fast():


def test_crystal_graph_perturb_legacy():
np.random.seed(0)
structure_perturbed = structure.copy()
structure_perturbed.perturb(distance=0.1)
fixed_rng = np.random.default_rng(0)
with patch("numpy.random.default_rng", return_value=fixed_rng):
structure_perturbed.perturb(distance=0.1)

start = perf_counter()
graph = converter_legacy(structure_perturbed)
print("Legacy test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201

assert list(graph.atom_frac_coord.shape) == [8, 3]
assert list(graph.atom_graph.shape) == [410, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 6).sum().item() == 50

assert list(graph.bond_graph.shape) == [688, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 90
assert (graph.bond_graph[:, 1] == 36).sum().item() == 17
assert (graph.bond_graph[:, 3] == 36).sum().item() == 17
assert (graph.bond_graph[:, 2] == 306).sum().item() == 10
assert list(graph.atom_graph.shape) == [420, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 6).sum().item() == 54

assert list(graph.bond_graph.shape) == [850, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 156
assert (graph.bond_graph[:, 1] == 36).sum().item() == 18
assert (graph.bond_graph[:, 3] == 36).sum().item() == 18
assert (graph.bond_graph[:, 2] == 306).sum().item() == 0
assert (graph.bond_graph[:, 4] == 120).sum().item() == 0
assert list(graph.lattice.shape) == [3, 3]
assert list(graph.undirected2directed.shape) == [205]
assert list(graph.directed2undirected.shape) == [410]
assert list(graph.undirected2directed.shape) == [210]
assert list(graph.directed2undirected.shape) == [420]


def test_crystal_graph_perturb_fast():
np.random.seed(0)
structure_perturbed = structure.copy()
structure_perturbed.perturb(distance=0.1)
fixed_rng = np.random.default_rng(0)
with patch("numpy.random.default_rng", return_value=fixed_rng):
structure_perturbed.perturb(distance=0.1)

start = perf_counter()
graph = converter_fast(structure_perturbed)
print("Fast test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201

assert list(graph.atom_frac_coord.shape) == [8, 3]
assert list(graph.atom_graph.shape) == [410, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 6).sum().item() == 50

assert list(graph.bond_graph.shape) == [688, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 90
assert (graph.bond_graph[:, 1] == 36).sum().item() == 17
assert (graph.bond_graph[:, 3] == 36).sum().item() == 17
assert (graph.bond_graph[:, 2] == 306).sum().item() == 10
assert list(graph.atom_graph.shape) == [420, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 6).sum().item() == 54

assert list(graph.bond_graph.shape) == [850, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 156
assert (graph.bond_graph[:, 1] == 36).sum().item() == 18
assert (graph.bond_graph[:, 3] == 36).sum().item() == 18
assert (graph.bond_graph[:, 2] == 306).sum().item() == 0
assert (graph.bond_graph[:, 4] == 120).sum().item() == 0
assert list(graph.lattice.shape) == [3, 3]
assert list(graph.undirected2directed.shape) == [205]
assert list(graph.directed2undirected.shape) == [410]
assert list(graph.undirected2directed.shape) == [210]
assert list(graph.directed2undirected.shape) == [420]


def test_crystal_graph_isotropic_strained_legacy():
Expand Down

0 comments on commit b4f05c5

Please sign in to comment.