Skip to content

Commit

Permalink
separate optimizer from the optimize to be able to use geom opt in zn…
Browse files Browse the repository at this point in the history
…draw, also keey pylint happy in md
  • Loading branch information
alinelena committed May 29, 2024
1 parent 404e6c5 commit 4fd03fa
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 27 deletions.
76 changes: 59 additions & 17 deletions janus_core/calculations/geom_opt.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Geometry optimization."""

from logging import Logger
from pathlib import Path
from typing import Any, Callable, Optional
import warnings

from ase import Atoms
from ase.io import read, write
from ase.optimize import LBFGS
from ase.optimize.optimize import Optimizer

try:
from ase.filters import FrechetCellFilter as DefaultFilter
Expand All @@ -20,6 +22,58 @@
from janus_core.helpers.utils import none_to_dict, spacegroup


def set_optimizer(
struct: Atoms,
filter_func: Optional[Callable] = DefaultFilter,
filter_kwargs: Optional[dict[str, Any]] = None,
optimizer: Callable = LBFGS,
opt_kwargs: Optional[ASEOptArgs] = None,
logger: Optional[Logger] = None,
) -> tuple[Optimizer, Optional[Atoms]]:
"""
Set optimizer for geometry optimisation.
Parameters
----------
struct : Atoms
Atoms object to optimize geometry for.
filter_func : Optional[callable]
Apply constraints to atoms through ASE filter function.
Default is `FrechetCellFilter` if available otherwise `ExpCellFilter`.
filter_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to filter_func. Default is {}.
optimizer : callable
ASE optimization function. Default is `LBFGS`.
opt_kwargs : Optional[ASEOptArgs]
Keyword arguments to pass to optimizer. Default is {}.
logger : Optional[Logger]
Logger instance. Default is None.
Returns
-------
tuple[Optimizer, Optional[Atoms]]
Optimizer and options the filtered atoms structure.
"""
[filter_kwargs, opt_kwargs] = none_to_dict([filter_kwargs, opt_kwargs])
filtered_struct = None
if filter_func is not None:
filtered_struct = filter_func(struct, **filter_kwargs)
dyn = optimizer(filtered_struct, **opt_kwargs)
if logger:
logger.info("Using filter %s", filter_func.__name__)
logger.info("Using optimizer %s", optimizer.__name__)
if "hydrostatic_strain" in filter_kwargs:
logger.info(
"hydrostatic_strain: %s", filter_kwargs["hydrostatic_strain"]
)
if "constant_volume" in filter_kwargs:
logger.info("constant_volume: %s", filter_kwargs["constant_volume"])
else:
dyn = optimizer(struct, **opt_kwargs)

return (dyn, filtered_struct)


def optimize( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches
struct: Atoms,
fmax: float = 0.1,
Expand Down Expand Up @@ -78,8 +132,8 @@ def optimize( # pylint: disable=too-many-arguments,too-many-locals,too-many-bra
struct: Atoms
Structure with geometry optimized.
"""
[filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs, log_kwargs] = none_to_dict(
[filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs, log_kwargs]
[write_kwargs, traj_kwargs, log_kwargs] = none_to_dict(
[write_kwargs, traj_kwargs, log_kwargs]
)

write_kwargs.setdefault(
Expand Down Expand Up @@ -108,21 +162,9 @@ def optimize( # pylint: disable=too-many-arguments,too-many-locals,too-many-bra
if logger:
logger.info(message)

if filter_func is not None:
filtered_struct = filter_func(struct, **filter_kwargs)
dyn = optimizer(filtered_struct, **opt_kwargs)
if logger:
logger.info("Using filter %s", filter_func.__name__)
logger.info("Using optimizer %s", optimizer.__name__)
if "hydrostatic_strain" in filter_kwargs:
logger.info(
"hydrostatic_strain: %s", filter_kwargs["hydrostatic_strain"]
)
if "constant_volume" in filter_kwargs:
logger.info("constant_volume: %s", filter_kwargs["constant_volume"])

else:
dyn = optimizer(struct, **opt_kwargs)
dyn, filtered_struct = set_optimizer(
struct, filter_func, filter_kwargs, optimizer, opt_kwargs, logger
)

if logger:
logger.info("Starting geometry optimization")
Expand Down
15 changes: 6 additions & 9 deletions janus_core/cli/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,17 +342,14 @@ def md(
for key in ["thermostat_time", "barostat_time", "bulk_modulus", "pressure"]:
del dyn_kwargs[key]
dyn = NVT(**dyn_kwargs)

if ensemble == "npt":
elif ensemble == "npt":
del dyn_kwargs["friction"]
dyn = NPT(**dyn_kwargs)

if ensemble == "nph":
elif ensemble == "nph":
for key in ["friction", "barostat_time"]:
del dyn_kwargs[key]
dyn = NPH(**dyn_kwargs)

if ensemble == "nve":
elif ensemble == "nve":
for key in [
"thermostat_time",
"barostat_time",
Expand All @@ -362,12 +359,12 @@ def md(
]:
del dyn_kwargs[key]
dyn = NVE(**dyn_kwargs)

if ensemble == "nvt-nh":
elif ensemble == "nvt-nh":
for key in ["barostat_time", "bulk_modulus", "pressure", "friction"]:
del dyn_kwargs[key]
dyn = NVT_NH(**dyn_kwargs)

else:
raise ValueError(f"Unsupported Ensemble ({ensemble})")
# Store inputs for yaml summary
inputs = dyn_kwargs | {"ensemble": ensemble}

Expand Down
2 changes: 1 addition & 1 deletion janus_core/helpers/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def config_logger(
capture_warnings: bool = True,
filemode: Literal["r", "w", "a", "x", "r+", "w+", "a+", "x+"] = "w",
force: bool = True,
):
) -> Optional[logging.Logger]:
"""
Configure logger with yaml-styled format.
Expand Down

0 comments on commit 4fd03fa

Please sign in to comment.