diff --git a/matgl/ext/ase.py b/matgl/ext/ase.py index 2e712d9a..3e541fb9 100644 --- a/matgl/ext/ase.py +++ b/matgl/ext/ase.py @@ -8,9 +8,12 @@ import sys from typing import TYPE_CHECKING -import dgl import numpy as np import torch + +if TYPE_CHECKING: + import dgl + from ase import Atoms, units from ase.calculators.calculator import Calculator, all_changes from ase.constraints import ExpCellFilter @@ -22,7 +25,6 @@ from ase.optimize.lbfgs import LBFGS, LBFGSLineSearch from ase.optimize.mdmin import MDMin from ase.optimize.sciopt import SciPyFminBFGS, SciPyFminCG -from dgl.backend import tensor from pymatgen.core.structure import Molecule, Structure from pymatgen.io.ase import AseAtomsAdaptor from pymatgen.optimization.neighbors import find_points_in_spheres @@ -79,7 +81,6 @@ def get_graph(self, atoms: Atoms) -> tuple[dgl.DGLGraph, list]: pbc = np.array([1, 1, 1], dtype=int) element_types = self.element_types Z = np.array([np.eye(len(element_types))[element_types.index(i.symbol)] for i in atoms]) - atomic_number = np.array(atoms.get_atomic_numbers()) lattice_matrix = np.ascontiguousarray(np.array(atoms.get_cell()), dtype=float) volume = atoms.get_volume() cart_coords = np.ascontiguousarray(np.array(atoms.get_positions()), dtype=float) @@ -98,16 +99,17 @@ def get_graph(self, atoms: Atoms) -> tuple[dgl.DGLGraph, list]: images[exclude_self], bond_dist[exclude_self], ) - u, v = tensor(src_id), tensor(dst_id) - g = dgl.graph((u, v)) - g.edata["pbc_offset"] = torch.tensor(images) - g.edata["lattice"] = tensor(np.stack([lattice_matrix for _ in range(g.num_edges())])) - g.ndata["attr"] = tensor(Z) - g.ndata["node_type"] = tensor(np.hstack([[element_types.index(i.symbol)] for i in atoms])) - g.ndata["pos"] = tensor(cart_coords) - g.ndata["volume"] = tensor([volume for i in range(atomic_number.shape[0])]) - state_attr = [0.0, 0.0] - g.edata["pbc_offshift"] = torch.matmul(tensor(images), tensor(lattice_matrix)) + g, state_attr = super().get_graph_from_processed_structure( + AseAtomsAdaptor().get_structure(atoms), + src_id, + dst_id, + images, + [lattice_matrix], + Z, + element_types, + cart_coords, + ) + g.ndata["volume"] = torch.tensor([volume] * g.num_nodes()) return g, state_attr diff --git a/matgl/ext/pymatgen.py b/matgl/ext/pymatgen.py index 0cb59b9f..f1dda648 100644 --- a/matgl/ext/pymatgen.py +++ b/matgl/ext/pymatgen.py @@ -1,11 +1,14 @@ """Interface with pymatgen objects.""" from __future__ import annotations -import dgl +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import dgl + import numpy as np import scipy.sparse as sp import torch -from dgl.backend import tensor from pymatgen.core import Element, Molecule, Structure from pymatgen.optimization.neighbors import find_points_in_spheres @@ -62,17 +65,17 @@ def get_graph(self, mol: Molecule) -> tuple[dgl.DGLGraph, list]: nbonds /= natoms adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(natoms, dtype=np.bool_) adj = adj.tocoo() - u, v = tensor(adj.row), tensor(adj.col) - g = dgl.graph((u, v)) - g = dgl.to_bidirected(g) - g.ndata["pos"] = tensor(R) - g.ndata["attr"] = tensor(Z) - g.ndata["node_type"] = tensor(np.hstack([[element_types.index(site.specie.symbol)] for site in mol])) - g.edata["pbc_offset"] = torch.zeros(g.num_edges(), 3) - g.edata["lattice"] = torch.zeros(g.num_edges(), 3, 3) + g, _ = super().get_graph_from_processed_structure( + structure=mol, + src_id=adj.row, + dst_id=adj.col, + images=torch.zeros(len(adj.row), 3), + lattice_matrix=torch.zeros(1, 3, 3), + Z=Z, + element_types=element_types, + cart_coords=R, + ) state_attr = [weight, nbonds] - g.edata["pbc_offshift"] = torch.zeros(g.num_edges(), 3) - return g, state_attr @@ -105,7 +108,6 @@ def get_graph(self, structure: Structure) -> tuple[dgl.DGLGraph, list]: pbc = np.array([1, 1, 1], dtype=int) element_types = self.element_types Z = np.array([np.eye(len(element_types))[element_types.index(site.specie.symbol)] for site in structure]) - atomic_number = np.array([site.specie.Z for site in structure]) lattice_matrix = np.ascontiguousarray(np.array(structure.lattice.matrix), dtype=float) volume = structure.volume cart_coords = np.ascontiguousarray(np.array(structure.cart_coords), dtype=float) @@ -124,13 +126,15 @@ def get_graph(self, structure: Structure) -> tuple[dgl.DGLGraph, list]: images[exclude_self], bond_dist[exclude_self], ) - u, v = tensor(src_id), tensor(dst_id) - g = dgl.graph((u, v)) - g.edata["pbc_offset"] = torch.tensor(images) - g.edata["lattice"] = tensor(np.stack([lattice_matrix for _ in range(g.num_edges())])) - g.ndata["attr"] = tensor(Z) - g.ndata["node_type"] = tensor(np.hstack([[element_types.index(site.specie.symbol)] for site in structure])) - g.ndata["pos"] = tensor(cart_coords) - g.ndata["volume"] = tensor([volume for _ in range(atomic_number.shape[0])]) - state_attr = [0.0, 0.0] + g, state_attr = super().get_graph_from_processed_structure( + structure, + src_id, + dst_id, + images, + [lattice_matrix], + Z, + element_types, + cart_coords, + ) + g.ndata["volume"] = torch.tensor([volume] * g.num_nodes()) return g, state_attr diff --git a/matgl/graph/compute.py b/matgl/graph/compute.py index a01d0cbb..6d828d08 100644 --- a/matgl/graph/compute.py +++ b/matgl/graph/compute.py @@ -60,10 +60,12 @@ def compute_3body(g: dgl.DGLGraph): src_id, dst_id = (triple_bond_indices[:, 0], triple_bond_indices[:, 1]) l_g = dgl.graph((src_id, dst_id)) - l_g.ndata["bond_dist"] = g.edata["bond_dist"] - l_g.ndata["bond_vec"] = g.edata["bond_vec"] - l_g.ndata["pbc_offset"] = g.edata["pbc_offset"] - l_g.ndata["n_triple_ij"] = n_triple_ij + three_body_id = np.unique(triple_bond_indices) + max_three_body_id = max(np.concatenate([three_body_id + 1, [0]])) + l_g.ndata["bond_dist"] = g.edata["bond_dist"][:max_three_body_id] + l_g.ndata["bond_vec"] = g.edata["bond_vec"][:max_three_body_id] + l_g.ndata["pbc_offset"] = g.edata["pbc_offset"][:max_three_body_id] + l_g.ndata["n_triple_ij"] = n_triple_ij[:max_three_body_id] n_triple_s = torch.tensor(n_triple_s, dtype=torch.int64) # type: ignore return l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s @@ -124,16 +126,14 @@ def create_line_graph(g_batched: dgl.DGLGraph, threebody_cutoff: float): g_unbatched = dgl.unbatch(g_batched) l_g_unbatched = [] for g in g_unbatched: - if g.edges()[0].size(dim=0) > 0: - valid_three_body = g.edata["bond_dist"] <= threebody_cutoff - src_id_with_three_body = g.edges()[0][valid_three_body] - dst_id_with_three_body = g.edges()[1][valid_three_body] - graph_with_three_body = dgl.graph((src_id_with_three_body, dst_id_with_three_body)) - graph_with_three_body.edata["bond_dist"] = g.edata["bond_dist"][valid_three_body] - graph_with_three_body.edata["bond_vec"] = g.edata["bond_vec"][valid_three_body] - graph_with_three_body.edata["pbc_offset"] = g.edata["pbc_offset"][valid_three_body] - if graph_with_three_body.edata["bond_dist"].size(dim=0) > 0: - l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s = compute_3body(graph_with_three_body) - l_g_unbatched.append(l_g) + valid_three_body = g.edata["bond_dist"] <= threebody_cutoff + src_id_with_three_body = g.edges()[0][valid_three_body] + dst_id_with_three_body = g.edges()[1][valid_three_body] + graph_with_three_body = dgl.graph((src_id_with_three_body, dst_id_with_three_body)) + graph_with_three_body.edata["bond_dist"] = g.edata["bond_dist"][valid_three_body] + graph_with_three_body.edata["bond_vec"] = g.edata["bond_vec"][valid_three_body] + graph_with_three_body.edata["pbc_offset"] = g.edata["pbc_offset"][valid_three_body] + l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s = compute_3body(graph_with_three_body) + l_g_unbatched.append(l_g) l_g_batched = dgl.batch(l_g_unbatched) return l_g_batched diff --git a/matgl/graph/converters.py b/matgl/graph/converters.py index 3d26b0de..7a192b20 100644 --- a/matgl/graph/converters.py +++ b/matgl/graph/converters.py @@ -2,10 +2,11 @@ from __future__ import annotations import abc -from typing import TYPE_CHECKING -if TYPE_CHECKING: - import dgl +import dgl +import numpy as np +import torch +from dgl.backend import tensor class GraphConverter(metaclass=abc.ABCMeta): @@ -14,8 +15,48 @@ class GraphConverter(metaclass=abc.ABCMeta): @abc.abstractmethod def get_graph(self, structure) -> tuple[dgl.DGLGraph, list]: """Args: - structure: Input crystals or molecule. + structure: Input crystals or molecule. + + Returns: + DGLGraph object, state_attr + """ + + def get_graph_from_processed_structure( + self, + structure, + src_id, + dst_id, + images, + lattice_matrix, + Z, + element_types, + cart_coords, + ) -> tuple[dgl.DGLGraph, list]: + """Construct a dgl graph from processed structure and bond information. + + Args: + structure: Input crystals or molecule of pymatgen structure or molecule types. + src_id: site indices for starting point of bonds. + dst_id: site indices for destination point of bonds. + images: the periodic image offsets for the bonds. + lattice_matrix: lattice information of the structure. + Z: Atomic number information of all atoms in the structure. + element_types: Element symbols of all atoms in the structure. + cart_coords: Cartisian coordinates of all atoms in the structure. Returns: DGLGraph object, state_attr + """ + u, v = tensor(src_id), tensor(dst_id) + g = dgl.graph((u, v)) + n_missing_node = len(structure) - g.num_nodes() # isolated atoms without bonds + g.add_nodes(n_missing_node) + g.edata["pbc_offset"] = torch.tensor(images) + g.edata["lattice"] = tensor(np.repeat(lattice_matrix, g.num_edges(), axis=0)) + g.ndata["attr"] = tensor(Z) + g.ndata["node_type"] = tensor(np.hstack([[element_types.index(site.specie.symbol)] for site in structure])) + g.ndata["pos"] = tensor(cart_coords) + state_attr = [0.0, 0.0] + g.edata["pbc_offshift"] = torch.matmul(tensor(images), tensor(lattice_matrix[0])) + return g, state_attr diff --git a/matgl/graph/data.py b/matgl/graph/data.py index 49616d7a..9fca9c3f 100644 --- a/matgl/graph/data.py +++ b/matgl/graph/data.py @@ -320,8 +320,8 @@ def __getitem__(self, idx: int): self.line_graphs[idx], self.state_attr[idx], self.energies[idx], - torch.tensor(self.forces[idx]), - torch.tensor(self.stresses[idx]), # type: ignore + torch.tensor(self.forces[idx]).float(), + torch.tensor(self.stresses[idx]).float(), # type: ignore ) def __len__(self): diff --git a/matgl/layers/_basis.py b/matgl/layers/_basis.py index e5d60d0d..7573ed33 100644 --- a/matgl/layers/_basis.py +++ b/matgl/layers/_basis.py @@ -60,12 +60,11 @@ class SphericalBesselFunction: """Calculate the spherical Bessel function based on sympy + pytorch implementations.""" def __init__(self, max_l: int, max_n: int = 5, cutoff: float = 5.0, smooth: bool = False): - """ - Args: - max_l: int, max order (excluding l) - max_n: int, max number of roots used in each l - cutoff: float, cutoff radius - smooth: Whether to smooth the function. + """Args: + max_l: int, max order (excluding l) + max_n: int, max number of roots used in each l + cutoff: float, cutoff radius + smooth: Whether to smooth the function. """ self.max_l = max_l self.max_n = max_n @@ -108,6 +107,8 @@ def _call_smooth_sbf(self, r): return torch.t(torch.stack(results)) def _call_sbf(self, r): + r_c = r.clone() + r_c[r_c > self.cutoff] = self.cutoff roots = SPHERICAL_BESSEL_ROOTS[: self.max_l, : self.max_n] results = [] @@ -117,7 +118,7 @@ def _call_sbf(self, r): func = self.funcs[i] func_add1 = self.funcs[i + 1] results.append( - func(r[:, None] * root[None, :] / self.cutoff) * factor / torch.abs(func_add1(root[None, :])) + func(r_c[:, None] * root[None, :] / self.cutoff) * factor / torch.abs(func_add1(root[None, :])) ) return torch.cat(results, axis=1) diff --git a/matgl/layers/_three_body.py b/matgl/layers/_three_body.py index 2d877ffb..23bf7d35 100644 --- a/matgl/layers/_three_body.py +++ b/matgl/layers/_three_body.py @@ -8,11 +8,7 @@ import torch from torch import nn -from matgl.utils.maths import ( - _block_repeat, - get_segment_indices_from_n, - scatter_sum, -) +from matgl.utils.maths import _block_repeat, get_segment_indices_from_n, scatter_sum if TYPE_CHECKING: import dgl @@ -70,6 +66,8 @@ def forward( num_segments=graph.num_edges(), dim=0, ) + if not new_bonds.data.shape[0]: + return edge_feat edge_feat_updated = edge_feat + self.update_network_bond(new_bonds) return edge_feat_updated diff --git a/matgl/models/_m3gnet.py b/matgl/models/_m3gnet.py index 9a0285ed..e9b6cbcd 100644 --- a/matgl/models/_m3gnet.py +++ b/matgl/models/_m3gnet.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING import dgl +import numpy as np import torch from torch import nn @@ -149,13 +150,19 @@ def __init__( ) self.basis_expansion = SphericalBesselWithHarmonics( - max_n=max_n, max_l=max_l, cutoff=cutoff, use_phi=use_phi, use_smooth=use_smooth + max_n=max_n, + max_l=max_l, + cutoff=cutoff, + use_phi=use_phi, + use_smooth=use_smooth, ) self.three_body_interactions = nn.ModuleList( { ThreeBodyInteractions( update_network_atom=MLP( - dims=[dim_node_embedding, degree], activation=nn.Sigmoid(), activate_last=True + dims=[dim_node_embedding, degree], + activation=nn.Sigmoid(), + activate_last=True, ), update_network_bond=GatedMLP(in_feats=degree, dims=[dim_edge_embedding], use_bias=False), ) @@ -208,7 +215,12 @@ def __init__( self.task_type = task_type self.is_intensive = is_intensive - def forward(self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, l_g: dgl.DGLGraph | None = None): + def forward( + self, + g: dgl.DGLGraph, + state_attr: torch.Tensor | None = None, + l_g: dgl.DGLGraph | None = None, + ): """Performs message passing and updates node representations. Args: @@ -228,10 +240,11 @@ def forward(self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, l_g: if l_g is None: l_g = create_line_graph(g, self.threebody_cutoff) else: - valid_three_body = g.edata["bond_dist"] <= self.threebody_cutoff - l_g.ndata["bond_vec"] = g.edata["bond_vec"][valid_three_body] - l_g.ndata["bond_dist"] = g.edata["bond_dist"][valid_three_body] - l_g.ndata["pbc_offset"] = g.edata["pbc_offset"][valid_three_body] + three_body_id = np.unique(np.concatenate(l_g.edges())) + max_three_body_id = max(np.concatenate([three_body_id + 1, [0]])) + l_g.ndata["bond_vec"] = g.edata["bond_vec"][:max_three_body_id] + l_g.ndata["bond_dist"] = g.edata["bond_dist"][:max_three_body_id] + l_g.ndata["pbc_offset"] = g.edata["pbc_offset"][:max_three_body_id] l_g.apply_edges(compute_theta_and_phi) g.edata["rbf"] = expanded_dists three_body_basis = self.basis_expansion(l_g) @@ -239,7 +252,12 @@ def forward(self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, l_g: num_node_feats, num_edge_feats, num_state_feats = self.embedding(node_types, g.edata["rbf"], state_attr) for i in range(self.n_blocks): num_edge_feats = self.three_body_interactions[i]( - g, l_g, three_body_basis, three_body_cutoff, num_node_feats, num_edge_feats + g, + l_g, + three_body_basis, + three_body_cutoff, + num_node_feats, + num_edge_feats, ) num_edge_feats, num_node_feats, num_state_feats = self.graph_layers[i]( g, num_edge_feats, num_node_feats, num_state_feats @@ -258,7 +276,10 @@ def forward(self, g: dgl.DGLGraph, state_attr: torch.Tensor | None = None, l_g: return torch.squeeze(output) def predict_structure( - self, structure, state_feats: torch.Tensor | None = None, graph_converter: GraphConverter | None = None + self, + structure, + state_feats: torch.Tensor | None = None, + graph_converter: GraphConverter | None = None, ): """Convenience method to directly predict property from structure. @@ -274,7 +295,7 @@ def predict_structure( from matgl.ext.pymatgen import Structure2Graph graph_converter = Structure2Graph(element_types=self.element_types, cutoff=self.cutoff) # type: ignore - g, stare_feats_default = graph_converter.get_graph(structure) + g, state_feats_default = graph_converter.get_graph(structure) if state_feats is None: - state_feats = torch.tensor(stare_feats_default) + state_feats = torch.tensor(state_feats_default) return self(g=g, state_attr=state_feats).detach() diff --git a/tests/apps/test_pes.py b/tests/apps/test_pes.py index 6c2a36e6..57d1d5fd 100644 --- a/tests/apps/test_pes.py +++ b/tests/apps/test_pes.py @@ -5,7 +5,8 @@ from matgl.apps.pes import Potential from matgl.models._m3gnet import M3GNet - +from pymatgen.core import Structure, Lattice +from matgl.ext.pymatgen import Structure2Graph, get_element_list @pytest.fixture() def model(): @@ -48,3 +49,27 @@ def test_potential_e(self, graph_MoS, model): assert [f.size(dim=0)] == [1] assert [s.size(dim=0)] == [1] assert [h.size(dim=0)] == [1] + + def test_potential_two_body(self, model): + structure = Structure(Lattice.cubic(10.0), ["Mo", "Mo"], [[0.0, 0, 0], [0.2, 0.0, 0.0]]) + element_types = get_element_list([structure]) + p2g = Structure2Graph(element_types=element_types, cutoff=5.0) + graph, state = p2g.get_graph(structure) + ff = Potential(model=model, calc_hessian=True) + e, f, s, h = ff(graph, torch.tensor(state)) + assert [torch.numel(e)] == [1] + assert [f.size(dim=0), f.size(dim=1)] == [2, 3] + assert [s.size(dim=0), s.size(dim=1)] == [3, 3] + assert [h.size(dim=0), h.size(dim=1)] == [6, 6] + + def test_potential_isolated_atom(self, model): + structure = Structure(Lattice.cubic(10.0), ["Mo"], [[0.0, 0, 0]]) + element_types = get_element_list([structure]) + p2g = Structure2Graph(element_types=element_types, cutoff=5.0) + graph, state = p2g.get_graph(structure) + ff = Potential(model=model, calc_hessian=True) + e, f, s, h = ff(graph, torch.tensor(state)) + assert [torch.numel(e)] == [1] + assert [f.size(dim=0), f.size(dim=1)] == [1, 3] + assert [s.size(dim=0), s.size(dim=1)] == [3, 3] + assert [h.size(dim=0), h.size(dim=1)] == [3, 3] \ No newline at end of file diff --git a/tests/utils/test_cutoff.py b/tests/utils/test_cutoff.py index d688419f..356971d4 100644 --- a/tests/utils/test_cutoff.py +++ b/tests/utils/test_cutoff.py @@ -14,7 +14,8 @@ def test_cosine(): assert_close( three_cutoff, torch.tensor([0.8536, 0.7270, 0.5782, 0.4218, 0.2730, 0.1464, 0.0545, 0.0062, 0.0000, 0.0000, 0.0000]), - atol=1e-4, rtol=0.0 + atol=1e-4, + rtol=0.0, ) @@ -25,10 +26,9 @@ def test_polymonial_cutoff(): envelope = polynomial_cutoff(r, three_body_cutoff) assert_close( envelope, - torch.tensor( - [0.8965, 0.7648, 0.5931, 0.4069, 0.2352, 0.1035, 0.0266, 0.0012, 0.0000, 0.0000, 0.0000] - ), - atol=1e-4, rtol=0.0 + torch.tensor([0.8965, 0.7648, 0.5931, 0.4069, 0.2352, 0.1035, 0.0266, 0.0012, 0.0000, 0.0000, 0.0000]), + atol=1e-4, + rtol=0.0, ) # test behaviour smoothing a SBF with cutoff @@ -49,4 +49,4 @@ def test_polymonial_cutoff(): envelope_res.backward(torch.ones_like(envelope_res), retain_graph=True) assert r.grad[-1] == 0.0 envelope_res.backward(torch.ones_like(envelope_res)) - assert r.grad[-1] == 0.0 \ No newline at end of file + assert r.grad[-1] == 0.0