-
Notifications
You must be signed in to change notification settings - Fork 68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ruff
fixes
#184
ruff
fixes
#184
Changes from all commits
e02ec19
d581717
110f4eb
9996c5d
1c74710
6fa3755
3023f84
6a4246d
62f3638
cc1f833
d628068
d0c6b5b
b4f05c5
7f7d6b6
c70f804
1d0b153
4f044eb
45a1fd7
b8f6cfa
e4ceb6a
d00a425
af86ce6
d7aaded
47a77e0
570bfed
ff8bf84
1f40088
b60f5e4
e53337e
1b84de4
59787a6
bda704b
4eed205
1d1ec65
4cafef7
d1bf176
107fb02
190d6e8
c9e8e5f
8875487
8c3ffa2
aeb9053
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,14 @@ | ||
from __future__ import annotations | ||
|
||
from time import perf_counter | ||
from unittest.mock import patch | ||
|
||
import numpy as np | ||
from pymatgen.core import Structure | ||
|
||
from chgnet import ROOT | ||
from chgnet.graph import CrystalGraphConverter | ||
|
||
np.random.seed(0) | ||
|
||
structure = Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif") | ||
converter = CrystalGraphConverter(atom_graph_cutoff=5, bond_graph_cutoff=3) | ||
converter_legacy = CrystalGraphConverter( | ||
|
@@ -127,55 +126,57 @@ def test_crystal_graph_different_cutoff_fast(): | |
|
||
|
||
def test_crystal_graph_perturb_legacy(): | ||
np.random.seed(0) | ||
DanielYang59 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
structure_perturbed = structure.copy() | ||
structure_perturbed.perturb(distance=0.1) | ||
fixed_rng = np.random.default_rng(0) | ||
with patch("numpy.random.default_rng", return_value=fixed_rng): | ||
structure_perturbed.perturb(distance=0.1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should probably add a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a good point, especially with the generator implementation. I didn't realize this downside before but it turns out with the new generator implementation every |
||
|
||
start = perf_counter() | ||
graph = converter_legacy(structure_perturbed) | ||
print("Legacy test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201 | ||
|
||
assert list(graph.atom_frac_coord.shape) == [8, 3] | ||
assert list(graph.atom_graph.shape) == [410, 2] | ||
assert (graph.atom_graph[:, 0] == 3).sum().item() == 53 | ||
assert (graph.atom_graph[:, 1] == 3).sum().item() == 53 | ||
assert (graph.atom_graph[:, 1] == 6).sum().item() == 50 | ||
|
||
assert list(graph.bond_graph.shape) == [688, 5] | ||
assert (graph.bond_graph[:, 0] == 1).sum().item() == 90 | ||
assert (graph.bond_graph[:, 1] == 36).sum().item() == 17 | ||
assert (graph.bond_graph[:, 3] == 36).sum().item() == 17 | ||
assert (graph.bond_graph[:, 2] == 306).sum().item() == 10 | ||
assert list(graph.atom_graph.shape) == [420, 2] | ||
assert (graph.atom_graph[:, 0] == 3).sum().item() == 54 | ||
assert (graph.atom_graph[:, 1] == 3).sum().item() == 54 | ||
assert (graph.atom_graph[:, 1] == 6).sum().item() == 54 | ||
|
||
assert list(graph.bond_graph.shape) == [850, 5] | ||
assert (graph.bond_graph[:, 0] == 1).sum().item() == 156 | ||
assert (graph.bond_graph[:, 1] == 36).sum().item() == 18 | ||
assert (graph.bond_graph[:, 3] == 36).sum().item() == 18 | ||
assert (graph.bond_graph[:, 2] == 306).sum().item() == 0 | ||
assert (graph.bond_graph[:, 4] == 120).sum().item() == 0 | ||
assert list(graph.lattice.shape) == [3, 3] | ||
assert list(graph.undirected2directed.shape) == [205] | ||
assert list(graph.directed2undirected.shape) == [410] | ||
assert list(graph.undirected2directed.shape) == [210] | ||
assert list(graph.directed2undirected.shape) == [420] | ||
|
||
|
||
def test_crystal_graph_perturb_fast(): | ||
np.random.seed(0) | ||
structure_perturbed = structure.copy() | ||
structure_perturbed.perturb(distance=0.1) | ||
fixed_rng = np.random.default_rng(0) | ||
with patch("numpy.random.default_rng", return_value=fixed_rng): | ||
structure_perturbed.perturb(distance=0.1) | ||
|
||
start = perf_counter() | ||
graph = converter_fast(structure_perturbed) | ||
print("Fast test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201 | ||
|
||
assert list(graph.atom_frac_coord.shape) == [8, 3] | ||
assert list(graph.atom_graph.shape) == [410, 2] | ||
assert (graph.atom_graph[:, 0] == 3).sum().item() == 53 | ||
assert (graph.atom_graph[:, 1] == 3).sum().item() == 53 | ||
assert (graph.atom_graph[:, 1] == 6).sum().item() == 50 | ||
|
||
assert list(graph.bond_graph.shape) == [688, 5] | ||
assert (graph.bond_graph[:, 0] == 1).sum().item() == 90 | ||
assert (graph.bond_graph[:, 1] == 36).sum().item() == 17 | ||
assert (graph.bond_graph[:, 3] == 36).sum().item() == 17 | ||
assert (graph.bond_graph[:, 2] == 306).sum().item() == 10 | ||
assert list(graph.atom_graph.shape) == [420, 2] | ||
assert (graph.atom_graph[:, 0] == 3).sum().item() == 54 | ||
assert (graph.atom_graph[:, 1] == 3).sum().item() == 54 | ||
assert (graph.atom_graph[:, 1] == 6).sum().item() == 54 | ||
|
||
assert list(graph.bond_graph.shape) == [850, 5] | ||
assert (graph.bond_graph[:, 0] == 1).sum().item() == 156 | ||
assert (graph.bond_graph[:, 1] == 36).sum().item() == 18 | ||
assert (graph.bond_graph[:, 3] == 36).sum().item() == 18 | ||
assert (graph.bond_graph[:, 2] == 306).sum().item() == 0 | ||
assert (graph.bond_graph[:, 4] == 120).sum().item() == 0 | ||
assert list(graph.lattice.shape) == [3, 3] | ||
assert list(graph.undirected2directed.shape) == [205] | ||
assert list(graph.directed2undirected.shape) == [410] | ||
assert list(graph.undirected2directed.shape) == [210] | ||
assert list(graph.directed2undirected.shape) == [420] | ||
|
||
|
||
def test_crystal_graph_isotropic_strained_legacy(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if this is intended, but I don't think
SystemExit
should be raised here? ff8bf84