Skip to content

Commit

Permalink
CHGNetCalculator add kwarg task: PredTask = "efsm"
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Nov 16, 2024
1 parent 0da2d15 commit cbbbf99
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 24 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
default_stages: [commit]
default_stages: [pre-commit]

default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.9
rev: v0.7.4
hooks:
- id: ruff
args: [--fix]
Expand All @@ -28,11 +28,11 @@ repos:
rev: v2.3.0
hooks:
- id: codespell
stages: [commit, commit-msg]
stages: [pre-commit, commit-msg]
args: [--check-filenames]

- repo: https://github.com/kynan/nbstripout
rev: 0.7.1
rev: 0.8.0
hooks:
- id: nbstripout
args: [--drop-empty-cells, --keep-output]
Expand All @@ -48,7 +48,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.12.0
rev: v9.15.0
hooks:
- id: eslint
types: [file]
Expand Down
38 changes: 25 additions & 13 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from ase.optimize.optimize import Optimizer
from typing_extensions import Self

from chgnet import PredTask

# We would like to thank M3GNet develop team for this module
# source: https://github.com/materialsvirtuallab/m3gnet

Expand All @@ -59,7 +61,7 @@ def __init__(
*,
use_device: str | None = None,
check_cuda_mem: bool = False,
stress_weight: float | None = 1 / 160.21766208,
stress_weight: float = units.GPa, # GPa to eV/A^3
on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
return_site_energies: bool = False,
**kwargs,
Expand Down Expand Up @@ -124,6 +126,7 @@ def calculate(
atoms: Atoms | None = None,
properties: list | None = None,
system_changes: list | None = None,
task: PredTask = "efsm",
) -> None:
"""Calculate various properties of the atoms using CHGNet.
Expand All @@ -133,6 +136,8 @@ def calculate(
Default is all properties.
system_changes (list | None): The changes made to the system.
Default is all changes.
task (PredTask): The task to perform. One of "e", "ef", "em", "efs", "efsm".
Default = "efsm"
"""
properties = properties or all_properties
system_changes = system_changes or all_changes
Expand All @@ -147,23 +152,30 @@ def calculate(
graph = self.model.graph_converter(structure)
model_prediction = self.model.predict_graph(
graph.to(self.device),
task="efsm",
task=task,
return_crystal_feas=True,
return_site_energies=self.return_site_energies,
)

# Convert Result
factor = 1 if not self.model.is_intensive else structure.composition.num_atoms
self.results.update(
energy=model_prediction["e"] * factor,
forces=model_prediction["f"],
free_energy=model_prediction["e"] * factor,
magmoms=model_prediction["m"],
stress=model_prediction["s"] * self.stress_weight,
crystal_fea=model_prediction["crystal_fea"],
extensive_factor = (
1 if not self.model.is_intensive else structure.composition.num_atoms
)
key_map = dict(
e=("energy", extensive_factor),
f=("forces", 1),
m=("magmoms", 1),
s=("stress", self.stress_weight),
)
self.results |= {
long_key: model_prediction[key] * factor
for key, (long_key, factor) in key_map.items()
if key in model_prediction
}
self.results["free_energy"] = self.results["energy"]
self.results["crystal_fea"] = model_prediction["crystal_fea"]
if self.return_site_energies:
self.results.update(energies=model_prediction["site_energies"])
self.results["energies"] = model_prediction["site_energies"]


class StructOptimizer:
Expand All @@ -174,7 +186,7 @@ def __init__(
model: CHGNet | CHGNetCalculator | None = None,
optimizer_class: Optimizer | str | None = "FIRE",
use_device: str | None = None,
stress_weight: float = 1 / 160.21766208,
stress_weight: float = units.GPa,
on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
) -> None:
"""Provide a trained CHGNet model and an optimizer to relax crystal structures.
Expand Down Expand Up @@ -773,7 +785,7 @@ def __init__(
model: CHGNet | CHGNetCalculator | None = None,
optimizer_class: Optimizer | str | None = "FIRE",
use_device: str | None = None,
stress_weight: float = 1 / 160.21766208,
stress_weight: float = units.GPa,
on_isolated_atoms: Literal["ignore", "warn", "error"] = "error",
) -> None:
"""Initialize a structure optimizer object for calculation of bulk modulus.
Expand Down
9 changes: 6 additions & 3 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import os
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, get_args

import torch
from pymatgen.core import Structure
from torch import Tensor, nn

from chgnet import PredTask
from chgnet.graph import CrystalGraph, CrystalGraphConverter
from chgnet.graph.crystalgraph import TORCH_DTYPE
from chgnet.model.composition_model import AtomRef
Expand All @@ -27,7 +28,6 @@
if TYPE_CHECKING:
from typing_extensions import Self

from chgnet import PredTask

module_dir = os.path.dirname(os.path.abspath(__file__))

Expand Down Expand Up @@ -603,7 +603,7 @@ def predict_graph(
Args:
graph (CrystalGraph | Sequence[CrystalGraph]): CrystalGraph(s) to predict.
task (str): can be 'e' 'ef', 'em', 'efs', 'efsm'
task (PredTask): one of 'e', 'ef', 'em', 'efs', 'efsm'
Default = "efsm"
return_site_energies (bool): whether to return per-site energies.
Default = False
Expand All @@ -626,6 +626,9 @@ def predict_graph(
raise TypeError(
f"{type(graph)=} must be CrystalGraph or list of CrystalGraphs"
)
valid_tasks = get_args(PredTask)
if task not in valid_tasks:
raise ValueError(f"Invalid {task=}. Must be one of {valid_tasks}.")

model_device = next(self.parameters()).device

Expand Down
28 changes: 26 additions & 2 deletions tests/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import pickle
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, get_args

import numpy as np
import pytest
Expand All @@ -22,7 +22,7 @@
from chgnet.graph import CrystalGraphConverter
from chgnet.model import StructOptimizer
from chgnet.model.dynamics import CHGNetCalculator, EquationOfState, MolecularDynamics
from chgnet.model.model import CHGNet
from chgnet.model.model import CHGNet, PredTask

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -314,3 +314,27 @@ def test_md_crystal_feas_log(tmp_path: Path, monkeypatch: MonkeyPatch):
assert crystal_feas[0][1] == approx(-1.4285042, abs=1e-5)
assert crystal_feas[10][0] == approx(-0.0020592688, abs=1e-5)
assert crystal_feas[10][1] == approx(-1.4284436, abs=1e-5)


@pytest.mark.parametrize("task", [*get_args(PredTask)])
def test_calculator_task_valid(task: PredTask):
"""Test that the task kwarg of CHGNetCalculator.calculate() works correctly."""
key_map = dict(e="energy", f="forces", m="magmoms", s="stress")
calculator = CHGNetCalculator()
atoms = AseAtomsAdaptor.get_atoms(structure)
atoms.calc = calculator

calculator.calculate(atoms=atoms, task=task)

for key, prop in key_map.items():
assert (prop in calculator.results) == (key in task)


def test_calculator_task_invalid():
"""Test that invalid task raises ValueError."""
calculator = CHGNetCalculator()
atoms = AseAtomsAdaptor.get_atoms(structure)
atoms.calc = calculator

with pytest.raises(ValueError, match="Invalid task='invalid'."):
calculator.calculate(atoms=atoms, task="invalid")
4 changes: 3 additions & 1 deletion tests/test_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def test_relaxation(
assert {*traj.__dict__} == {
*"atoms energies forces stresses magmoms atom_positions cells".split()
}
assert len(traj) == 2 if algorithm == "legacy" else 4
assert len(traj) == (
2 if algorithm == "legacy" else 4
), f"{len(traj)=}, {algorithm=}"

# make sure final structure is more relaxed than initial one
assert traj.energies[-1] == pytest.approx(-58.94209, rel=1e-4)
Expand Down

0 comments on commit cbbbf99

Please sign in to comment.