diff --git a/tests/test_crystal_graph.py b/tests/test_crystal_graph.py index 8834d235..4022e704 100644 --- a/tests/test_crystal_graph.py +++ b/tests/test_crystal_graph.py @@ -1,6 +1,7 @@ from __future__ import annotations from time import perf_counter +from unittest.mock import patch import numpy as np from pymatgen.core import Structure @@ -8,8 +9,6 @@ from chgnet import ROOT from chgnet.graph import CrystalGraphConverter -np.random.seed(0) - structure = Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif") converter = CrystalGraphConverter(atom_graph_cutoff=5, bond_graph_cutoff=3) converter_legacy = CrystalGraphConverter( @@ -127,55 +126,57 @@ def test_crystal_graph_different_cutoff_fast(): def test_crystal_graph_perturb_legacy(): - np.random.seed(0) structure_perturbed = structure.copy() - structure_perturbed.perturb(distance=0.1) + fixed_rng = np.random.default_rng(0) + with patch("numpy.random.default_rng", return_value=fixed_rng): + structure_perturbed.perturb(distance=0.1) start = perf_counter() graph = converter_legacy(structure_perturbed) print("Legacy test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201 assert list(graph.atom_frac_coord.shape) == [8, 3] - assert list(graph.atom_graph.shape) == [410, 2] - assert (graph.atom_graph[:, 0] == 3).sum().item() == 53 - assert (graph.atom_graph[:, 1] == 3).sum().item() == 53 - assert (graph.atom_graph[:, 1] == 6).sum().item() == 50 - - assert list(graph.bond_graph.shape) == [688, 5] - assert (graph.bond_graph[:, 0] == 1).sum().item() == 90 - assert (graph.bond_graph[:, 1] == 36).sum().item() == 17 - assert (graph.bond_graph[:, 3] == 36).sum().item() == 17 - assert (graph.bond_graph[:, 2] == 306).sum().item() == 10 + assert list(graph.atom_graph.shape) == [420, 2] + assert (graph.atom_graph[:, 0] == 3).sum().item() == 54 + assert (graph.atom_graph[:, 1] == 3).sum().item() == 54 + assert (graph.atom_graph[:, 1] == 6).sum().item() == 54 + + assert list(graph.bond_graph.shape) == [850, 5] + assert (graph.bond_graph[:, 0] == 1).sum().item() == 156 + assert (graph.bond_graph[:, 1] == 36).sum().item() == 18 + assert (graph.bond_graph[:, 3] == 36).sum().item() == 18 + assert (graph.bond_graph[:, 2] == 306).sum().item() == 0 assert (graph.bond_graph[:, 4] == 120).sum().item() == 0 assert list(graph.lattice.shape) == [3, 3] - assert list(graph.undirected2directed.shape) == [205] - assert list(graph.directed2undirected.shape) == [410] + assert list(graph.undirected2directed.shape) == [210] + assert list(graph.directed2undirected.shape) == [420] def test_crystal_graph_perturb_fast(): - np.random.seed(0) structure_perturbed = structure.copy() - structure_perturbed.perturb(distance=0.1) + fixed_rng = np.random.default_rng(0) + with patch("numpy.random.default_rng", return_value=fixed_rng): + structure_perturbed.perturb(distance=0.1) start = perf_counter() graph = converter_fast(structure_perturbed) print("Fast test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201 assert list(graph.atom_frac_coord.shape) == [8, 3] - assert list(graph.atom_graph.shape) == [410, 2] - assert (graph.atom_graph[:, 0] == 3).sum().item() == 53 - assert (graph.atom_graph[:, 1] == 3).sum().item() == 53 - assert (graph.atom_graph[:, 1] == 6).sum().item() == 50 - - assert list(graph.bond_graph.shape) == [688, 5] - assert (graph.bond_graph[:, 0] == 1).sum().item() == 90 - assert (graph.bond_graph[:, 1] == 36).sum().item() == 17 - assert (graph.bond_graph[:, 3] == 36).sum().item() == 17 - assert (graph.bond_graph[:, 2] == 306).sum().item() == 10 + assert list(graph.atom_graph.shape) == [420, 2] + assert (graph.atom_graph[:, 0] == 3).sum().item() == 54 + assert (graph.atom_graph[:, 1] == 3).sum().item() == 54 + assert (graph.atom_graph[:, 1] == 6).sum().item() == 54 + + assert list(graph.bond_graph.shape) == [850, 5] + assert (graph.bond_graph[:, 0] == 1).sum().item() == 156 + assert (graph.bond_graph[:, 1] == 36).sum().item() == 18 + assert (graph.bond_graph[:, 3] == 36).sum().item() == 18 + assert (graph.bond_graph[:, 2] == 306).sum().item() == 0 assert (graph.bond_graph[:, 4] == 120).sum().item() == 0 assert list(graph.lattice.shape) == [3, 3] - assert list(graph.undirected2directed.shape) == [205] - assert list(graph.directed2undirected.shape) == [410] + assert list(graph.undirected2directed.shape) == [210] + assert list(graph.directed2undirected.shape) == [420] def test_crystal_graph_isotropic_strained_legacy():