Skip to content

Commit

Permalink
Fix sequences of atoms (stfc#245)
Browse files Browse the repository at this point in the history
* Raise error if sequence passed to single struct calcs

* Test for struct inputs

* Set default read_kwargs for CLI

* Add tests for CLI read kwargs

* Add NaCl trajetory for tests

* Update janus_core/calculations/eos.py

* Clarify reason for limiting ASE index

---------

Co-authored-by: Alin Marin Elena <[email protected]>
  • Loading branch information
ElliottKasoar and alinelena authored Aug 6, 2024
1 parent f182589 commit c2bea2a
Show file tree
Hide file tree
Showing 21 changed files with 436 additions and 21 deletions.
11 changes: 10 additions & 1 deletion janus_core/calculations/eos.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Equation of State."""

from collections.abc import Sequence
from copy import copy
from typing import Any, Optional

Expand All @@ -22,7 +23,7 @@ class EoS(FileNameMixin):
Parameters
----------
struct : Atoms
Structure.
Structure to calculate equation of state for.
struct_name : Optional[str]
Name of structure. Default is None.
min_volume : float
Expand Down Expand Up @@ -147,6 +148,14 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
self.write_kwargs = write_kwargs
self.log_kwargs = log_kwargs

if not isinstance(struct, Atoms):
if isinstance(struct, Sequence) and isinstance(struct[0], Atoms):
raise NotImplementedError(
"The equation of state can only be calculated for one Atoms "
"object at a time currently"
)
raise ValueError("`struct` must be an ASE Atoms object")

log_kwargs.setdefault("name", __name__)
self.logger = config_logger(**log_kwargs)
self.tracker = config_tracker(self.logger, **tracker_kwargs)
Expand Down
8 changes: 8 additions & 0 deletions janus_core/calculations/geom_opt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Geometry optimization."""

from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
import warnings

Expand Down Expand Up @@ -160,6 +161,13 @@ def __init__( # pylint: disable=too-many-arguments
self.write_kwargs = write_kwargs
self.traj_kwargs = traj_kwargs

if not isinstance(struct, Atoms):
if isinstance(struct, Sequence) and isinstance(struct[0], Atoms):
raise NotImplementedError(
"Only one Atoms object at a time can currently be optimized"
)
raise ValueError("`struct` must be an ASE Atoms object")

FileNameMixin.__init__(self, self.struct, None, None)

self.write_kwargs.setdefault(
Expand Down
8 changes: 8 additions & 0 deletions janus_core/calculations/md.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=too-many-lines
"""Run molecular dynamics simulations."""

from collections.abc import Sequence
import datetime
from functools import partial
from itertools import combinations_with_replacement
Expand Down Expand Up @@ -319,6 +320,13 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta
self.ensemble = ensemble
self.seed = seed

if not isinstance(struct, Atoms):
if isinstance(struct, Sequence) and isinstance(struct[0], Atoms):
raise NotImplementedError(
"MD can only be run for one Atoms object at a time currently"
)
raise ValueError("`struct` must be an ASE Atoms object")

FileNameMixin.__init__(self, struct, struct_name, file_prefix, ensemble)

self.write_kwargs.setdefault(
Expand Down
9 changes: 9 additions & 0 deletions janus_core/calculations/phonons.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Phonon calculations."""

from collections.abc import Sequence
from typing import Any, Optional

from ase import Atoms
Expand Down Expand Up @@ -134,6 +135,14 @@ def __init__( # pylint: disable=too-many-arguments,disable=too-many-locals
tracker_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to `config_tracker`. Default is {}.
"""
if not isinstance(struct, Atoms):
if isinstance(struct, Sequence) and isinstance(struct[0], Atoms):
raise NotImplementedError(
"Phonons can only be calculated for one Atoms object at a time "
"currently"
)
raise ValueError("`struct` must be an ASE Atoms object")

FileNameMixin.__init__(self, struct, struct_name, file_prefix)

[minimize_kwargs, log_kwargs, tracker_kwargs] = none_to_dict(
Expand Down
7 changes: 4 additions & 3 deletions janus_core/cli/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
Device,
LogPath,
ModelPath,
ReadKwargs,
ReadKwargsAll,
StructPath,
Summary,
WriteKwargs,
Expand Down Expand Up @@ -63,7 +63,7 @@ def descriptors(
),
),
] = None,
read_kwargs: ReadKwargs = None,
read_kwargs: ReadKwargsAll = None,
calc_kwargs: CalcKwargs = None,
write_kwargs: WriteKwargs = None,
log: LogPath = "descriptors.log",
Expand Down Expand Up @@ -95,7 +95,8 @@ def descriptors(
Path to save structure with calculated results. Default is inferred from name
of the structure file.
read_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to ase.io.read. Default is {}.
Keyword arguments to pass to ase.io.read. By default,
read_kwargs["index"] is ":".
calc_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to the selected calculator. Default is {}.
write_kwargs : Optional[dict[str, Any]]
Expand Down
11 changes: 8 additions & 3 deletions janus_core/cli/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
LogPath,
MinimizeKwargs,
ModelPath,
ReadKwargs,
ReadKwargsFirst,
StructPath,
Summary,
WriteKwargs,
Expand All @@ -25,6 +25,7 @@
end_summary,
parse_typer_dicts,
save_struct_calc,
set_read_kwargs_index,
start_summary,
yaml_converter_callback,
)
Expand Down Expand Up @@ -70,7 +71,7 @@ def eos(
arch: Architecture = "mace_mp",
device: Device = "cpu",
model_path: ModelPath = None,
read_kwargs: ReadKwargs = None,
read_kwargs: ReadKwargsFirst = None,
calc_kwargs: CalcKwargs = None,
file_prefix: Annotated[
Optional[Path],
Expand Down Expand Up @@ -128,7 +129,8 @@ def eos(
model_path : Optional[str]
Path to MLIP model. Default is `None`.
read_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to ase.io.read. Default is {}.
Keyword arguments to pass to ase.io.read. By default,
read_kwargs["index"] is 0.
calc_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to the selected calculator. Default is {}.
file_prefix : Optional[PathLike]
Expand All @@ -151,6 +153,9 @@ def eos(
if not eos_type in get_args(EoSNames):
raise ValueError(f"Fit type must be one of: {get_args(EoSNames)}")

# Read only first structure by default and ensure only one image is read
set_read_kwargs_index(read_kwargs)

# Set up single point calculator
s_point = SinglePoint(
struct_path=struct,
Expand Down
11 changes: 8 additions & 3 deletions janus_core/cli/geomopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
LogPath,
MinimizeKwargs,
ModelPath,
ReadKwargs,
ReadKwargsFirst,
StructPath,
Summary,
WriteKwargs,
Expand All @@ -25,6 +25,7 @@
end_summary,
parse_typer_dicts,
save_struct_calc,
set_read_kwargs_index,
start_summary,
yaml_converter_callback,
)
Expand Down Expand Up @@ -143,7 +144,7 @@ def geomopt(
str,
Option(help="Path if saving optimization frames. [default: None]"),
] = None,
read_kwargs: ReadKwargs = None,
read_kwargs: ReadKwargsFirst = None,
calc_kwargs: CalcKwargs = None,
minimize_kwargs: MinimizeKwargs = None,
write_kwargs: WriteKwargs = None,
Expand Down Expand Up @@ -192,7 +193,8 @@ def geomopt(
traj : Optional[str]
Path if saving optimization frames. Default is None.
read_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to ase.io.read. Default is {}.
Keyword arguments to pass to ase.io.read. By default,
read_kwargs["index"] is 0.
calc_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to the selected calculator. Default is {}.
minimize_kwargs : Optional[dict[str, Any]]
Expand All @@ -215,6 +217,9 @@ def geomopt(
[read_kwargs, calc_kwargs, minimize_kwargs, write_kwargs]
)

# Read only first structure by default and ensure only one image is read
set_read_kwargs_index(read_kwargs)

# Set up single point calculator
s_point = SinglePoint(
struct_path=struct,
Expand Down
11 changes: 8 additions & 3 deletions janus_core/cli/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
MinimizeKwargs,
ModelPath,
PostProcessKwargs,
ReadKwargs,
ReadKwargsFirst,
StructPath,
Summary,
WriteKwargs,
Expand All @@ -27,6 +27,7 @@
end_summary,
parse_typer_dicts,
save_struct_calc,
set_read_kwargs_index,
start_summary,
yaml_converter_callback,
)
Expand Down Expand Up @@ -80,7 +81,7 @@ def md(
arch: Architecture = "mace_mp",
device: Device = "cpu",
model_path: ModelPath = None,
read_kwargs: ReadKwargs = None,
read_kwargs: ReadKwargsFirst = None,
calc_kwargs: CalcKwargs = None,
equil_steps: Annotated[
int,
Expand Down Expand Up @@ -232,7 +233,8 @@ def md(
model_path : Optional[str]
Path to MLIP model. Default is `None`.
read_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to ase.io.read. Default is {}.
Keyword arguments to pass to ase.io.read. By default,
read_kwargs["index"] is 0.
calc_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to the selected calculator. Default is {}.
equil_steps : int
Expand Down Expand Up @@ -330,6 +332,9 @@ def md(
if not ensemble in get_args(Ensembles):
raise ValueError(f"ensemble must be in {get_args(Ensembles)}")

# Read only first structure by default and ensure only one image is read
set_read_kwargs_index(read_kwargs)

# Set up single point calculator
s_point = SinglePoint(
struct_path=struct,
Expand Down
11 changes: 8 additions & 3 deletions janus_core/cli/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
LogPath,
MinimizeKwargs,
ModelPath,
ReadKwargs,
ReadKwargsFirst,
StructPath,
Summary,
)
Expand All @@ -24,6 +24,7 @@
end_summary,
parse_typer_dicts,
save_struct_calc,
set_read_kwargs_index,
start_summary,
yaml_converter_callback,
)
Expand Down Expand Up @@ -113,7 +114,7 @@ def phonons(
arch: Architecture = "mace_mp",
device: Device = "cpu",
model_path: ModelPath = None,
read_kwargs: ReadKwargs = None,
read_kwargs: ReadKwargsFirst = None,
calc_kwargs: CalcKwargs = None,
file_prefix: Annotated[
Optional[Path],
Expand Down Expand Up @@ -187,7 +188,8 @@ def phonons(
model_path : Optional[str]
Path to MLIP model. Default is `None`.
read_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to ase.io.read. Default is {}.
Keyword arguments to pass to ase.io.read. By default,
read_kwargs["index"] is 0.
calc_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to the selected calculator. Default is {}.
file_prefix : Optional[PathLike]
Expand All @@ -208,6 +210,9 @@ def phonons(
[read_kwargs, calc_kwargs, minimize_kwargs]
)

# Read only first structure by default and ensure only one image is read
set_read_kwargs_index(read_kwargs)

# Set up single point calculator
s_point = SinglePoint(
struct_path=struct,
Expand Down
7 changes: 4 additions & 3 deletions janus_core/cli/singlepoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Device,
LogPath,
ModelPath,
ReadKwargs,
ReadKwargsAll,
StructPath,
Summary,
WriteKwargs,
Expand Down Expand Up @@ -59,7 +59,7 @@ def singlepoint(
),
),
] = None,
read_kwargs: ReadKwargs = None,
read_kwargs: ReadKwargsAll = None,
calc_kwargs: CalcKwargs = None,
write_kwargs: WriteKwargs = None,
log: LogPath = "singlepoint.log",
Expand Down Expand Up @@ -87,7 +87,8 @@ def singlepoint(
Path to save structure with calculated results. Default is inferred from name
of the structure file.
read_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to ase.io.read. Default is {}.
Keyword arguments to pass to ase.io.read. By default,
read_kwargs["index"] is ":".
calc_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to the selected calculator. Default is {}.
write_kwargs : Optional[dict[str, Any]]
Expand Down
18 changes: 16 additions & 2 deletions janus_core/cli/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,28 @@ def __str__(self):
Device = Annotated[str, Option(help="Device to run calculations on.")]
ModelPath = Annotated[str, Option(help="Path to MLIP model. [default: None]")]

ReadKwargs = Annotated[
ReadKwargsAll = Annotated[
TyperDict,
Option(
parser=parse_dict_class,
help=(
"""
Keyword arguments to pass to ase.io.read. Must be passed as a dictionary
wrapped in quotes, e.g. "{'key' : value}". [default: "{}"]
wrapped in quotes, e.g. "{'key' : value}". [default: "{'index': ':'}"]
"""
),
metavar="DICT",
),
]

ReadKwargsFirst = Annotated[
TyperDict,
Option(
parser=parse_dict_class,
help=(
"""
Keyword arguments to pass to ase.io.read. Must be passed as a dictionary
wrapped in quotes, e.g. "{'key' : value}". [default: "{'index': 0}"]
"""
),
metavar="DICT",
Expand Down
20 changes: 20 additions & 0 deletions janus_core/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,26 @@
from janus_core.helpers.utils import dict_remove_hyphens


def set_read_kwargs_index(read_kwargs: dict[str, Any]) -> None:
"""
Set default read_kwargs["index"] and check its value is an integer.
To ensure only a single Atoms object is read, slices such as ":" are forbidden.
Parameters
----------
read_kwargs : dict[str, Any]
Keyword arguments to be passed to ase.io.read. If specified,
read_kwargs["index"] must be an integer, and if not, a default value
of 0 is set.
"""
read_kwargs.setdefault("index", 0)
try:
int(read_kwargs["index"])
except ValueError as e:
raise ValueError("`read_kwargs['index']` must be an integer") from e


def parse_typer_dicts(typer_dicts: list[TyperDict]) -> list[dict]:
"""
Convert list of TyperDict objects to list of dictionaries.
Expand Down
Loading

0 comments on commit c2bea2a

Please sign in to comment.