diff --git a/docs/source/conf.py b/docs/source/conf.py index b5e1a8e7..7b968d57 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -200,4 +200,5 @@ ("py:class", "ellipsis"), ("py:class", "janus_core.helpers.stats.T"), ("py:class", "phonopy.structure.atoms.PhonopyAtoms"), + ("py:class", "ase.optimize.optimize.Optimizer"), ] diff --git a/janus_core/calculations/geom_opt.py b/janus_core/calculations/geom_opt.py index 063c1d20..d60a61ea 100644 --- a/janus_core/calculations/geom_opt.py +++ b/janus_core/calculations/geom_opt.py @@ -1,5 +1,6 @@ """Geometry optimization.""" +from logging import Logger from pathlib import Path from typing import Any, Callable, Optional import warnings @@ -7,6 +8,7 @@ from ase import Atoms from ase.io import read, write from ase.optimize import LBFGS +from ase.optimize.optimize import Optimizer try: from ase.filters import FrechetCellFilter as DefaultFilter @@ -20,8 +22,61 @@ from janus_core.helpers.utils import none_to_dict, spacegroup -def optimize( - # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements +def set_optimizer( + struct: Atoms, + filter_func: Optional[Callable] = DefaultFilter, + filter_kwargs: Optional[dict[str, Any]] = None, + optimizer: Callable = LBFGS, + opt_kwargs: Optional[ASEOptArgs] = None, + logger: Optional[Logger] = None, +) -> tuple[Optimizer, Optional[Atoms]]: + """ + Set optimizer for geometry optimisation. + + Parameters + ---------- + struct : Atoms + Atoms object to optimize geometry for. + filter_func : Optional[callable] + Apply constraints to atoms through ASE filter function. + Default is `FrechetCellFilter` if available otherwise `ExpCellFilter`. + filter_kwargs : Optional[dict[str, Any]] + Keyword arguments to pass to filter_func. Default is {}. + optimizer : callable + ASE optimization function. Default is `LBFGS`. + opt_kwargs : Optional[ASEOptArgs] + Keyword arguments to pass to optimizer. Default is {}. + logger : Optional[Logger] + Logger instance. Default is None. + + Returns + ------- + tuple[Optimizer, Optional[Atoms]] + Optimizer and options the filtered atoms structure. + """ + [filter_kwargs, opt_kwargs] = none_to_dict([filter_kwargs, opt_kwargs]) + filtered_struct = None + if filter_func is not None: + filtered_struct = filter_func(struct, **filter_kwargs) + dyn = optimizer(filtered_struct, **opt_kwargs) + if logger: + logger.info("Using filter %s", filter_func.__name__) + logger.info("Using optimizer %s", optimizer.__name__) + if "hydrostatic_strain" in filter_kwargs: + logger.info( + "hydrostatic_strain: %s", filter_kwargs["hydrostatic_strain"] + ) + if "constant_volume" in filter_kwargs: + logger.info("constant_volume: %s", filter_kwargs["constant_volume"]) + if "scalar_pressure" in filter_kwargs: + logger.info("scalar_pressure: %s", filter_kwargs["scalar_pressure"]) + else: + dyn = optimizer(struct, **opt_kwargs) + + return (dyn, filtered_struct) + + +def optimize( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches struct: Atoms, fmax: float = 0.1, steps: int = 1000, @@ -79,8 +134,8 @@ def optimize( struct: Atoms Structure with geometry optimized. """ - [filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs, log_kwargs] = none_to_dict( - [filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs, log_kwargs] + [opt_kwargs, write_kwargs, traj_kwargs, log_kwargs] = none_to_dict( + [opt_kwargs, write_kwargs, traj_kwargs, log_kwargs] ) write_kwargs.setdefault( @@ -109,23 +164,9 @@ def optimize( if logger: logger.info(message) - if filter_func is not None: - filtered_struct = filter_func(struct, **filter_kwargs) - dyn = optimizer(filtered_struct, **opt_kwargs) - if logger: - logger.info("Using filter %s", filter_func.__name__) - logger.info("Using optimizer %s", optimizer.__name__) - if "hydrostatic_strain" in filter_kwargs: - logger.info( - "hydrostatic_strain: %s", filter_kwargs["hydrostatic_strain"] - ) - if "constant_volume" in filter_kwargs: - logger.info("constant_volume: %s", filter_kwargs["constant_volume"]) - if "scalar_pressure" in filter_kwargs: - logger.info("scalar_pressure: %s", filter_kwargs["scalar_pressure"]) - - else: - dyn = optimizer(struct, **opt_kwargs) + dyn, filtered_struct = set_optimizer( + struct, filter_func, filter_kwargs, optimizer, opt_kwargs, logger + ) if logger: logger.info("Starting geometry optimization") diff --git a/janus_core/cli/md.py b/janus_core/cli/md.py index c793b9ee..c5296715 100644 --- a/janus_core/cli/md.py +++ b/janus_core/cli/md.py @@ -342,17 +342,14 @@ def md( for key in ["thermostat_time", "barostat_time", "bulk_modulus", "pressure"]: del dyn_kwargs[key] dyn = NVT(**dyn_kwargs) - - if ensemble == "npt": + elif ensemble == "npt": del dyn_kwargs["friction"] dyn = NPT(**dyn_kwargs) - - if ensemble == "nph": + elif ensemble == "nph": for key in ["friction", "barostat_time"]: del dyn_kwargs[key] dyn = NPH(**dyn_kwargs) - - if ensemble == "nve": + elif ensemble == "nve": for key in [ "thermostat_time", "barostat_time", @@ -362,12 +359,12 @@ def md( ]: del dyn_kwargs[key] dyn = NVE(**dyn_kwargs) - - if ensemble == "nvt-nh": + elif ensemble == "nvt-nh": for key in ["barostat_time", "bulk_modulus", "pressure", "friction"]: del dyn_kwargs[key] dyn = NVT_NH(**dyn_kwargs) - + else: + raise ValueError(f"Unsupported Ensemble ({ensemble})") # Store inputs for yaml summary inputs = dyn_kwargs | {"ensemble": ensemble} diff --git a/janus_core/helpers/log.py b/janus_core/helpers/log.py index c2ab7731..13f82c8d 100644 --- a/janus_core/helpers/log.py +++ b/janus_core/helpers/log.py @@ -92,7 +92,7 @@ def config_logger( capture_warnings: bool = True, filemode: Literal["r", "w", "a", "x", "r+", "w+", "a+", "x+"] = "w", force: bool = True, -): +) -> Optional[logging.Logger]: """ Configure logger with yaml-styled format.