Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Training and prediction of structures without 3-body interactions #92

Merged
merged 28 commits into from
Jun 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
ee78198
Update README.md
JiQi535 Jun 18, 2023
0e31d9f
allow structures with isolated atoms and/or two-body regions
Jun 20, 2023
8b76d10
Merge branch 'materialsvirtuallab:main' into main
JiQi535 Jun 20, 2023
33e6c6b
Merge branch 'main' of https://github.com/JiQi535/m3gnet-dgl
Jun 20, 2023
eb5b2b1
correction for max_three_body_id for graph with only two-body interac…
Jun 21, 2023
99f1863
move graph construction from ext to matgl/graph/converters.py
Jun 21, 2023
e39112e
Merge branch 'main' of https://github.com/materialsvirtuallab/matgl i…
Jun 21, 2023
a80fa1a
Merge branch 'materialsvirtuallab-main'
Jun 21, 2023
f52daad
Merge branch 'main' of https://github.com/materialsvirtuallab/matgl
Jun 21, 2023
4351277
add three_body_id and ensure correct training for isolated and 2-body…
Jun 24, 2023
67b898b
Merge branch 'main' of https://github.com/materialsvirtuallab/matgl
Jun 24, 2023
a690ce1
remove artifical bond and allow nodes not associated with edges in dg…
Jun 24, 2023
7c77cd2
clean up comments
Jun 24, 2023
2dd9f89
clean up pylint
Jun 24, 2023
2fda8cb
pylint
Jun 24, 2023
3add061
pylint
Jun 24, 2023
fddde5e
correct num_nodes for structures with part of atoms having 3-b intera…
Jun 24, 2023
629ccbd
tests from Kenko for isolated atoms and two-body regions
Jun 25, 2023
113f207
clean up
Jun 25, 2023
540a879
black
Jun 25, 2023
d67a8bc
lint
Jun 25, 2023
7854ebd
lint
Jun 25, 2023
8d6b839
isort
Jun 25, 2023
93278f5
consistent to compute_3body
Jun 25, 2023
afedda3
ruff
Jun 25, 2023
6c9cb65
black
Jun 25, 2023
bc18944
mypy
Jun 25, 2023
282ed93
sorry, black again
Jun 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions matgl/ext/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
import sys
from typing import TYPE_CHECKING

import dgl
import numpy as np
import torch

if TYPE_CHECKING:
import dgl

from ase import Atoms, units
from ase.calculators.calculator import Calculator, all_changes
from ase.constraints import ExpCellFilter
Expand All @@ -22,7 +25,6 @@
from ase.optimize.lbfgs import LBFGS, LBFGSLineSearch
from ase.optimize.mdmin import MDMin
from ase.optimize.sciopt import SciPyFminBFGS, SciPyFminCG
from dgl.backend import tensor
from pymatgen.core.structure import Molecule, Structure
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.optimization.neighbors import find_points_in_spheres
Expand Down Expand Up @@ -79,7 +81,6 @@ def get_graph(self, atoms: Atoms) -> tuple[dgl.DGLGraph, list]:
pbc = np.array([1, 1, 1], dtype=int)
element_types = self.element_types
Z = np.array([np.eye(len(element_types))[element_types.index(i.symbol)] for i in atoms])
atomic_number = np.array(atoms.get_atomic_numbers())
lattice_matrix = np.ascontiguousarray(np.array(atoms.get_cell()), dtype=float)
volume = atoms.get_volume()
cart_coords = np.ascontiguousarray(np.array(atoms.get_positions()), dtype=float)
Expand All @@ -98,16 +99,17 @@ def get_graph(self, atoms: Atoms) -> tuple[dgl.DGLGraph, list]:
images[exclude_self],
bond_dist[exclude_self],
)
u, v = tensor(src_id), tensor(dst_id)
g = dgl.graph((u, v))
g.edata["pbc_offset"] = torch.tensor(images)
g.edata["lattice"] = tensor(np.stack([lattice_matrix for _ in range(g.num_edges())]))
g.ndata["attr"] = tensor(Z)
g.ndata["node_type"] = tensor(np.hstack([[element_types.index(i.symbol)] for i in atoms]))
g.ndata["pos"] = tensor(cart_coords)
g.ndata["volume"] = tensor([volume for i in range(atomic_number.shape[0])])
state_attr = [0.0, 0.0]
g.edata["pbc_offshift"] = torch.matmul(tensor(images), tensor(lattice_matrix))
g, state_attr = super().get_graph_from_processed_structure(
AseAtomsAdaptor().get_structure(atoms),
src_id,
dst_id,
images,
[lattice_matrix],
Z,
element_types,
cart_coords,
)
g.ndata["volume"] = torch.tensor([volume] * g.num_nodes())
return g, state_attr


Expand Down
48 changes: 26 additions & 22 deletions matgl/ext/pymatgen.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Interface with pymatgen objects."""
from __future__ import annotations

import dgl
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import dgl

import numpy as np
import scipy.sparse as sp
import torch
from dgl.backend import tensor
from pymatgen.core import Element, Molecule, Structure
from pymatgen.optimization.neighbors import find_points_in_spheres

Expand Down Expand Up @@ -62,17 +65,17 @@ def get_graph(self, mol: Molecule) -> tuple[dgl.DGLGraph, list]:
nbonds /= natoms
adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(natoms, dtype=np.bool_)
adj = adj.tocoo()
u, v = tensor(adj.row), tensor(adj.col)
g = dgl.graph((u, v))
g = dgl.to_bidirected(g)
g.ndata["pos"] = tensor(R)
g.ndata["attr"] = tensor(Z)
g.ndata["node_type"] = tensor(np.hstack([[element_types.index(site.specie.symbol)] for site in mol]))
g.edata["pbc_offset"] = torch.zeros(g.num_edges(), 3)
g.edata["lattice"] = torch.zeros(g.num_edges(), 3, 3)
g, _ = super().get_graph_from_processed_structure(
structure=mol,
src_id=adj.row,
dst_id=adj.col,
images=torch.zeros(len(adj.row), 3),
lattice_matrix=torch.zeros(1, 3, 3),
Z=Z,
element_types=element_types,
cart_coords=R,
)
state_attr = [weight, nbonds]
g.edata["pbc_offshift"] = torch.zeros(g.num_edges(), 3)

return g, state_attr


Expand Down Expand Up @@ -105,7 +108,6 @@ def get_graph(self, structure: Structure) -> tuple[dgl.DGLGraph, list]:
pbc = np.array([1, 1, 1], dtype=int)
element_types = self.element_types
Z = np.array([np.eye(len(element_types))[element_types.index(site.specie.symbol)] for site in structure])
atomic_number = np.array([site.specie.Z for site in structure])
lattice_matrix = np.ascontiguousarray(np.array(structure.lattice.matrix), dtype=float)
volume = structure.volume
cart_coords = np.ascontiguousarray(np.array(structure.cart_coords), dtype=float)
Expand All @@ -124,13 +126,15 @@ def get_graph(self, structure: Structure) -> tuple[dgl.DGLGraph, list]:
images[exclude_self],
bond_dist[exclude_self],
)
u, v = tensor(src_id), tensor(dst_id)
g = dgl.graph((u, v))
g.edata["pbc_offset"] = torch.tensor(images)
g.edata["lattice"] = tensor(np.stack([lattice_matrix for _ in range(g.num_edges())]))
g.ndata["attr"] = tensor(Z)
g.ndata["node_type"] = tensor(np.hstack([[element_types.index(site.specie.symbol)] for site in structure]))
g.ndata["pos"] = tensor(cart_coords)
g.ndata["volume"] = tensor([volume for _ in range(atomic_number.shape[0])])
state_attr = [0.0, 0.0]
g, state_attr = super().get_graph_from_processed_structure(
structure,
src_id,
dst_id,
images,
[lattice_matrix],
Z,
element_types,
cart_coords,
)
g.ndata["volume"] = torch.tensor([volume] * g.num_nodes())
return g, state_attr
30 changes: 15 additions & 15 deletions matgl/graph/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ def compute_3body(g: dgl.DGLGraph):

src_id, dst_id = (triple_bond_indices[:, 0], triple_bond_indices[:, 1])
l_g = dgl.graph((src_id, dst_id))
l_g.ndata["bond_dist"] = g.edata["bond_dist"]
l_g.ndata["bond_vec"] = g.edata["bond_vec"]
l_g.ndata["pbc_offset"] = g.edata["pbc_offset"]
l_g.ndata["n_triple_ij"] = n_triple_ij
three_body_id = np.unique(triple_bond_indices)
max_three_body_id = max(np.concatenate([three_body_id + 1, [0]]))
l_g.ndata["bond_dist"] = g.edata["bond_dist"][:max_three_body_id]
l_g.ndata["bond_vec"] = g.edata["bond_vec"][:max_three_body_id]
l_g.ndata["pbc_offset"] = g.edata["pbc_offset"][:max_three_body_id]
l_g.ndata["n_triple_ij"] = n_triple_ij[:max_three_body_id]
n_triple_s = torch.tensor(n_triple_s, dtype=torch.int64) # type: ignore
return l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s

Expand Down Expand Up @@ -124,16 +126,14 @@ def create_line_graph(g_batched: dgl.DGLGraph, threebody_cutoff: float):
g_unbatched = dgl.unbatch(g_batched)
l_g_unbatched = []
for g in g_unbatched:
if g.edges()[0].size(dim=0) > 0:
valid_three_body = g.edata["bond_dist"] <= threebody_cutoff
src_id_with_three_body = g.edges()[0][valid_three_body]
dst_id_with_three_body = g.edges()[1][valid_three_body]
graph_with_three_body = dgl.graph((src_id_with_three_body, dst_id_with_three_body))
graph_with_three_body.edata["bond_dist"] = g.edata["bond_dist"][valid_three_body]
graph_with_three_body.edata["bond_vec"] = g.edata["bond_vec"][valid_three_body]
graph_with_three_body.edata["pbc_offset"] = g.edata["pbc_offset"][valid_three_body]
if graph_with_three_body.edata["bond_dist"].size(dim=0) > 0:
l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s = compute_3body(graph_with_three_body)
l_g_unbatched.append(l_g)
valid_three_body = g.edata["bond_dist"] <= threebody_cutoff
src_id_with_three_body = g.edges()[0][valid_three_body]
dst_id_with_three_body = g.edges()[1][valid_three_body]
graph_with_three_body = dgl.graph((src_id_with_three_body, dst_id_with_three_body))
graph_with_three_body.edata["bond_dist"] = g.edata["bond_dist"][valid_three_body]
graph_with_three_body.edata["bond_vec"] = g.edata["bond_vec"][valid_three_body]
graph_with_three_body.edata["pbc_offset"] = g.edata["pbc_offset"][valid_three_body]
l_g, triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s = compute_3body(graph_with_three_body)
l_g_unbatched.append(l_g)
l_g_batched = dgl.batch(l_g_unbatched)
return l_g_batched
49 changes: 45 additions & 4 deletions matgl/graph/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from __future__ import annotations

import abc
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import dgl
import dgl
import numpy as np
import torch
from dgl.backend import tensor


class GraphConverter(metaclass=abc.ABCMeta):
Expand All @@ -14,8 +15,48 @@ class GraphConverter(metaclass=abc.ABCMeta):
@abc.abstractmethod
def get_graph(self, structure) -> tuple[dgl.DGLGraph, list]:
"""Args:
structure: Input crystals or molecule.
structure: Input crystals or molecule.

Returns:
DGLGraph object, state_attr
"""

def get_graph_from_processed_structure(
self,
structure,
src_id,
dst_id,
images,
lattice_matrix,
Z,
element_types,
cart_coords,
) -> tuple[dgl.DGLGraph, list]:
"""Construct a dgl graph from processed structure and bond information.

Args:
structure: Input crystals or molecule of pymatgen structure or molecule types.
src_id: site indices for starting point of bonds.
dst_id: site indices for destination point of bonds.
images: the periodic image offsets for the bonds.
lattice_matrix: lattice information of the structure.
Z: Atomic number information of all atoms in the structure.
element_types: Element symbols of all atoms in the structure.
cart_coords: Cartisian coordinates of all atoms in the structure.

Returns:
DGLGraph object, state_attr

"""
u, v = tensor(src_id), tensor(dst_id)
g = dgl.graph((u, v))
n_missing_node = len(structure) - g.num_nodes() # isolated atoms without bonds
g.add_nodes(n_missing_node)
g.edata["pbc_offset"] = torch.tensor(images)
g.edata["lattice"] = tensor(np.repeat(lattice_matrix, g.num_edges(), axis=0))
g.ndata["attr"] = tensor(Z)
g.ndata["node_type"] = tensor(np.hstack([[element_types.index(site.specie.symbol)] for site in structure]))
g.ndata["pos"] = tensor(cart_coords)
state_attr = [0.0, 0.0]
g.edata["pbc_offshift"] = torch.matmul(tensor(images), tensor(lattice_matrix[0]))
return g, state_attr
4 changes: 2 additions & 2 deletions matgl/graph/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,8 @@ def __getitem__(self, idx: int):
self.line_graphs[idx],
self.state_attr[idx],
self.energies[idx],
torch.tensor(self.forces[idx]),
torch.tensor(self.stresses[idx]), # type: ignore
torch.tensor(self.forces[idx]).float(),
torch.tensor(self.stresses[idx]).float(), # type: ignore
)

def __len__(self):
Expand Down
15 changes: 8 additions & 7 deletions matgl/layers/_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,11 @@ class SphericalBesselFunction:
"""Calculate the spherical Bessel function based on sympy + pytorch implementations."""

def __init__(self, max_l: int, max_n: int = 5, cutoff: float = 5.0, smooth: bool = False):
"""
Args:
max_l: int, max order (excluding l)
max_n: int, max number of roots used in each l
cutoff: float, cutoff radius
smooth: Whether to smooth the function.
"""Args:
max_l: int, max order (excluding l)
max_n: int, max number of roots used in each l
cutoff: float, cutoff radius
smooth: Whether to smooth the function.
"""
self.max_l = max_l
self.max_n = max_n
Expand Down Expand Up @@ -108,6 +107,8 @@ def _call_smooth_sbf(self, r):
return torch.t(torch.stack(results))

def _call_sbf(self, r):
r_c = r.clone()
r_c[r_c > self.cutoff] = self.cutoff
roots = SPHERICAL_BESSEL_ROOTS[: self.max_l, : self.max_n]

results = []
Expand All @@ -117,7 +118,7 @@ def _call_sbf(self, r):
func = self.funcs[i]
func_add1 = self.funcs[i + 1]
results.append(
func(r[:, None] * root[None, :] / self.cutoff) * factor / torch.abs(func_add1(root[None, :]))
func(r_c[:, None] * root[None, :] / self.cutoff) * factor / torch.abs(func_add1(root[None, :]))
)
return torch.cat(results, axis=1)

Expand Down
8 changes: 3 additions & 5 deletions matgl/layers/_three_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@
import torch
from torch import nn

from matgl.utils.maths import (
_block_repeat,
get_segment_indices_from_n,
scatter_sum,
)
from matgl.utils.maths import _block_repeat, get_segment_indices_from_n, scatter_sum

if TYPE_CHECKING:
import dgl
Expand Down Expand Up @@ -70,6 +66,8 @@ def forward(
num_segments=graph.num_edges(),
dim=0,
)
if not new_bonds.data.shape[0]:
return edge_feat
edge_feat_updated = edge_feat + self.update_network_bond(new_bonds)
return edge_feat_updated

Expand Down
Loading