Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow structures without three-body interaction in cutoff radius -- isolated atoms or two-body regions #85

Closed
46 changes: 29 additions & 17 deletions matgl/ext/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def get_graph(self, atoms: Atoms) -> tuple[dgl.DGLGraph, list]:
numerical_tol = 1.0e-8
pbc = np.array([1, 1, 1], dtype=int)
element_types = self.element_types
Z = np.array([np.eye(len(element_types))[element_types.index(i.symbol)] for i in atoms])
Z = np.array(
[np.eye(len(element_types))[element_types.index(i.symbol)] for i in atoms]
)
atomic_number = np.array(atoms.get_atomic_numbers())
lattice_matrix = np.ascontiguousarray(np.array(atoms.get_cell()), dtype=float)
volume = atoms.get_volume()
Expand All @@ -98,17 +100,17 @@ def get_graph(self, atoms: Atoms) -> tuple[dgl.DGLGraph, list]:
images[exclude_self],
bond_dist[exclude_self],
)
u, v = tensor(src_id), tensor(dst_id)
g = dgl.graph((u, v))
g.edata["pbc_offset"] = torch.tensor(images)
g.edata["lattice"] = tensor(np.stack([lattice_matrix for _ in range(g.num_edges())]))
g.ndata["attr"] = tensor(Z)
g.ndata["node_type"] = tensor(np.hstack([[element_types.index(i.symbol)] for i in atoms]))
g.ndata["pos"] = tensor(cart_coords)
g.ndata["volume"] = tensor([volume for i in range(atomic_number.shape[0])])
state_attr = [0.0, 0.0]
g.edata["pbc_offshift"] = torch.matmul(tensor(images), tensor(lattice_matrix))
return g, state_attr
return super().get_graph(
AseAtomsAdaptor().get_structure(atoms),
src_id,
dst_id,
images,
[lattice_matrix],
Z,
element_types,
cart_coords,
volume,
)


class M3GNetCalculator(Calculator):
Expand Down Expand Up @@ -160,19 +162,27 @@ def calculate(
"""
properties = properties or ["energy"]
system_changes = system_changes or all_changes
super().calculate(atoms=atoms, properties=properties, system_changes=system_changes)
super().calculate(
atoms=atoms, properties=properties, system_changes=system_changes
)
graph, state_attr_default = Atoms2Graph(self.element_types, self.cutoff).get_graph(atoms) # type: ignore
if self.state_attr is not None:
energies, forces, stresses, hessians = self.potential(graph, self.state_attr)
energies, forces, stresses, hessians = self.potential(
graph, self.state_attr
)
else:
energies, forces, stresses, hessians = self.potential(graph, state_attr_default)
energies, forces, stresses, hessians = self.potential(
graph, state_attr_default
)
self.results.update(
energy=energies.detach().cpu().numpy(),
free_energy=energies.detach().cpu().numpy(),
forces=forces.detach().cpu().numpy(),
)
if self.compute_stress:
self.results.update(stress=stresses.detach().cpu().numpy() * self.stress_weight)
self.results.update(
stress=stresses.detach().cpu().numpy() * self.stress_weight
)
if self.compute_hessian:
self.results.update(hessian=hessians.detach().cpu().numpy())

Expand Down Expand Up @@ -353,7 +363,9 @@ def __init__(
if isinstance(atoms, (Structure, Molecule)):
atoms = AseAtomsAdaptor().get_atoms(atoms)
self.atoms = atoms
self.atoms.set_calculator(M3GNetCalculator(potential=potential, state_attr=state_attr))
self.atoms.set_calculator(
M3GNetCalculator(potential=potential, state_attr=state_attr)
)

if taut is None:
taut = 100 * timestep * units.fs
Expand Down
61 changes: 37 additions & 24 deletions matgl/ext/pymatgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ def get_graph(self, mol: Molecule) -> tuple[dgl.DGLGraph, list]:
natoms = len(mol)
R = mol.cart_coords
element_types = self.element_types
Z = np.array([np.eye(len(element_types))[element_types.index(site.specie.symbol)] for site in mol])
Z = np.array(
[
np.eye(len(element_types))[element_types.index(site.specie.symbol)]
for site in mol
]
)
np.array([site.specie.Z for site in mol])
weight = mol.composition.weight / len(mol)
dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
Expand All @@ -62,17 +67,18 @@ def get_graph(self, mol: Molecule) -> tuple[dgl.DGLGraph, list]:
nbonds /= natoms
adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(natoms, dtype=np.bool_)
adj = adj.tocoo()
u, v = tensor(adj.row), tensor(adj.col)
g = dgl.graph((u, v))
g = dgl.to_bidirected(g)
g.ndata["pos"] = tensor(R)
g.ndata["attr"] = tensor(Z)
g.ndata["node_type"] = tensor(np.hstack([[element_types.index(site.specie.symbol)] for site in mol]))
g.edata["pbc_offset"] = torch.zeros(g.num_edges(), 3)
g.edata["lattice"] = torch.zeros(g.num_edges(), 3, 3)
g, _ = super().get_graph(
structure=mol,
src_id=adj.row,
dst_id=adj.col,
images=torch.zeros(len(adj.row), 3),
lattice_matrix=torch.zeros(1, 3, 3),
Z=Z,
element_types=element_types,
cart_coords=R,
volume=None
)
state_attr = [weight, nbonds]
g.edata["pbc_offshift"] = torch.zeros(g.num_edges(), 3)

return g, state_attr


Expand Down Expand Up @@ -104,9 +110,15 @@ def get_graph(self, structure: Structure) -> tuple[dgl.DGLGraph, list]:
numerical_tol = 1.0e-8
pbc = np.array([1, 1, 1], dtype=int)
element_types = self.element_types
Z = np.array([np.eye(len(element_types))[element_types.index(site.specie.symbol)] for site in structure])
atomic_number = np.array([site.specie.Z for site in structure])
lattice_matrix = np.ascontiguousarray(np.array(structure.lattice.matrix), dtype=float)
Z = np.array(
[
np.eye(len(element_types))[element_types.index(site.specie.symbol)]
for site in structure
]
)
lattice_matrix = np.ascontiguousarray(
np.array(structure.lattice.matrix), dtype=float
)
volume = structure.volume
cart_coords = np.ascontiguousarray(np.array(structure.cart_coords), dtype=float)
src_id, dst_id, images, bond_dist = find_points_in_spheres(
Expand All @@ -124,13 +136,14 @@ def get_graph(self, structure: Structure) -> tuple[dgl.DGLGraph, list]:
images[exclude_self],
bond_dist[exclude_self],
)
u, v = tensor(src_id), tensor(dst_id)
g = dgl.graph((u, v))
g.edata["pbc_offset"] = torch.tensor(images)
g.edata["lattice"] = tensor(np.stack([lattice_matrix for _ in range(g.num_edges())]))
g.ndata["attr"] = tensor(Z)
g.ndata["node_type"] = tensor(np.hstack([[element_types.index(site.specie.symbol)] for site in structure]))
g.ndata["pos"] = tensor(cart_coords)
g.ndata["volume"] = tensor([volume for _ in range(atomic_number.shape[0])])
state_attr = [0.0, 0.0]
return g, state_attr
return super().get_graph(
structure,
src_id,
dst_id,
images,
[lattice_matrix],
Z,
element_types,
cart_coords,
volume,
)
49 changes: 31 additions & 18 deletions matgl/graph/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,12 @@ def compute_3body(g: dgl.DGLGraph):

src_id, dst_id = (triple_bond_indices[:, 0], triple_bond_indices[:, 1])
l_g = dgl.graph((src_id, dst_id))
l_g.ndata["bond_dist"] = g.edata["bond_dist"]
l_g.ndata["bond_vec"] = g.edata["bond_vec"]
l_g.ndata["pbc_offset"] = g.edata["pbc_offset"]
l_g.ndata["n_triple_ij"] = n_triple_ij
n_triple_s = torch.tensor(n_triple_s, dtype=torch.int64) # type: ignore
max_three_body_id = max(np.concatenate([src_id, [-1]]))
l_g.ndata["bond_dist"] = g.edata["bond_dist"][: max_three_body_id + 1]
l_g.ndata["bond_vec"] = g.edata["bond_vec"][: max_three_body_id + 1]
l_g.ndata["pbc_offset"] = g.edata["pbc_offset"][: max_three_body_id + 1]
l_g.ndata["n_triple_ij"] = n_triple_ij[: max_three_body_id + 1]
n_triple_s = torch.tensor(n_triple_s, dtype=torch.int64)
return l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s


Expand All @@ -81,7 +82,11 @@ def compute_pair_vector_and_distance(g: dgl.DGLGraph):
bond_vec = torch.zeros(g.num_edges(), 3)
bond_vec[:, :] = (
g.ndata["pos"][g.edges()[1][:].long(), :]
+ torch.squeeze(torch.matmul(g.edata["pbc_offset"].unsqueeze(1), torch.squeeze(g.edata["lattice"])))
+ torch.squeeze(
torch.matmul(
g.edata["pbc_offset"].unsqueeze(1), torch.squeeze(g.edata["lattice"])
)
)
- g.ndata["pos"][g.edges()[0][:].long(), :]
)

Expand All @@ -103,7 +108,9 @@ def compute_theta_and_phi(edges: dgl.udf.EdgeBatch):
"""
vec1 = edges.src["bond_vec"]
vec2 = edges.dst["bond_vec"]
cosine_theta = torch.sum(vec1 * vec2, dim=1) / (torch.norm(vec1, dim=1) * torch.norm(vec2, dim=1))
cosine_theta = torch.sum(vec1 * vec2, dim=1) / (
torch.norm(vec1, dim=1) * torch.norm(vec2, dim=1)
)
return {
"cos_theta": cosine_theta,
"phi": torch.zeros_like(cosine_theta),
Expand All @@ -124,16 +131,22 @@ def create_line_graph(g_batched: dgl.DGLGraph, threebody_cutoff: float):
g_unbatched = dgl.unbatch(g_batched)
l_g_unbatched = []
for g in g_unbatched:
if g.edges()[0].size(dim=0) > 0:
valid_three_body = g.edata["bond_dist"] <= threebody_cutoff
src_id_with_three_body = g.edges()[0][valid_three_body]
dst_id_with_three_body = g.edges()[1][valid_three_body]
graph_with_three_body = dgl.graph((src_id_with_three_body, dst_id_with_three_body))
graph_with_three_body.edata["bond_dist"] = g.edata["bond_dist"][valid_three_body]
graph_with_three_body.edata["bond_vec"] = g.edata["bond_vec"][valid_three_body]
graph_with_three_body.edata["pbc_offset"] = g.edata["pbc_offset"][valid_three_body]
if graph_with_three_body.edata["bond_dist"].size(dim=0) > 0:
l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s = compute_3body(graph_with_three_body)
l_g_unbatched.append(l_g)
valid_three_body = g.edata["bond_dist"] <= threebody_cutoff
src_id_with_three_body = g.edges()[0][valid_three_body]
dst_id_with_three_body = g.edges()[1][valid_three_body]
graph_with_three_body = dgl.graph(
(src_id_with_three_body, dst_id_with_three_body)
)
graph_with_three_body.edata["bond_dist"] = g.edata["bond_dist"][
valid_three_body
]
graph_with_three_body.edata["bond_vec"] = g.edata["bond_vec"][valid_three_body]
graph_with_three_body.edata["pbc_offset"] = g.edata["pbc_offset"][
valid_three_body
]
l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s = compute_3body(
graph_with_three_body
)
l_g_unbatched.append(l_g)
l_g_batched = dgl.batch(l_g_unbatched)
return l_g_batched
57 changes: 54 additions & 3 deletions matgl/graph/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,69 @@
import abc
from typing import TYPE_CHECKING

import numpy as np
import torch
from dgl.backend import tensor
from pymatgen.core import Molecule, Structure
from pymatgen.optimization.neighbors import find_points_in_spheres
if TYPE_CHECKING:
import dgl


class GraphConverter(metaclass=abc.ABCMeta):
"""Abstract base class for converters from input crystals/molecules to graphs."""

# def __init__(
# self,
# element_types: tuple[str, ...],
# cutoff: float = 5.0,
# ):

@abc.abstractmethod
def get_graph(self, structure) -> tuple[dgl.DGLGraph, list]:
def get_graph(
self,
structure,
src_id,
dst_id,
images,
lattice_matrix,
Z,
element_types,
cart_coords,
volume=None,
) -> tuple[dgl.DGLGraph, list]:
"""Args:
structure: Input crystals or molecule.

structure: Input crystals or molecule of pymatgen structure or molecule types.
src_id: site indices for starting point of bonds.
dst_id: site indices for destination point of bonds.
images: the periodic image offsets for the bonds.
lattice_matrix: lattice information of the structure.
Z: Atomic number information of all atoms in the structure.
element_types: Element symbols of all atoms in the structure.
cart_coords: Cartisian coordinates of all atoms in the structure.
volume: Volume of the structure.
Returns:
DGLGraph object, state_attr
"""
isolated_atoms = list(set(range(len(structure))).difference(src_id))
if not isolated_atoms:
u, v = tensor(src_id), tensor(dst_id)
else:
u, v = tensor(np.concatenate([src_id, isolated_atoms])), tensor(
np.concatenate([dst_id, isolated_atoms])
)
images = np.concatenate(
[images, np.repeat([[1.0, 0.0, 0.0]], len(isolated_atoms), axis=0)]
)
g = dgl.graph((u, v))
g.edata["pbc_offset"] = torch.tensor(images)
g.edata["lattice"] = tensor(np.repeat(lattice_matrix, g.num_edges(), axis=0))
g.ndata["attr"] = tensor(Z)
g.ndata["node_type"] = tensor(
np.hstack([[element_types.index(site.specie.symbol)] for site in structure])
)
g.ndata["pos"] = tensor(cart_coords)
g.ndata["volume"] = tensor([volume] * g.num_nodes())
state_attr = [0.0, 0.0]
g.edata["pbc_offshift"] = torch.matmul(tensor(images), tensor(lattice_matrix[0]))
return g, state_attr
Loading