diff --git a/docs/source/user_guide/command_line.rst b/docs/source/user_guide/command_line.rst index 14ce77de..03ee0e50 100644 --- a/docs/source/user_guide/command_line.rst +++ b/docs/source/user_guide/command_line.rst @@ -314,7 +314,7 @@ Calculate phonons with a 2x2x2 supercell, after geometry optimization (using the .. code-block:: bash - janus phonons --struct tests/data/NaCl.cif --supercell 2x2x2 --minimize --arch mace_mp --model-path small + janus phonons --struct tests/data/NaCl.cif --supercell 2 2 2 --minimize --arch mace_mp --model-path small This will save the Phonopy parameters, including displacements and force constants, to ``NaCl-phonopy.yml`` and ``NaCl-force_constants.hdf5``, @@ -324,7 +324,7 @@ Additionally, the ``--bands`` option can be added to calculate the band structur .. code-block:: bash - janus phonons --struct tests/data/NaCl.cif --supercell 2x2x2 --minimize --arch mace_mp --model-path small --bands + janus phonons --struct tests/data/NaCl.cif --supercell 2 2 2 --minimize --arch mace_mp --model-path small --bands If you need eigenvectors and group velocities written, add the ``--write-full`` option. This will generate a much larger file, but can be used to visualise phonon modes. @@ -333,7 +333,7 @@ Further calculations, including thermal properties, DOS, and PDOS, can also be c .. code-block:: bash - janus phonons --struct tests/data/NaCl.cif --supercell 2x3x4 --dos --pdos --thermal --temp-start 0 --temp-end 300 --temp-step 50 + janus phonons --struct tests/data/NaCl.cif --supercell 2 3 4 --dos --pdos --thermal --temp-start 0 --temp-end 300 --temp-step 50 This will create additional output files: ``NaCl-thermal.dat`` for the thermal properties (heat capacity, entropy, and free energy) diff --git a/janus_core/calculations/descriptors.py b/janus_core/calculations/descriptors.py index e5d3b8a8..32536e86 100644 --- a/janus_core/calculations/descriptors.py +++ b/janus_core/calculations/descriptors.py @@ -4,6 +4,8 @@ from typing import Any, Optional from ase import Atoms +from ase.calculators.calculator import Calculator +from ase.calculators.mixing import SumCalculator import numpy as np from janus_core.calculations.base import BaseCalculation @@ -163,12 +165,43 @@ def __init__( ): raise ValueError("Please attach a calculator to `struct`.") + if isinstance(self.struct, Atoms): + self._check_calculator(self.struct.calc) + if isinstance(self.struct, Sequence): + for image in self.struct: + self._check_calculator(image.calc) + # Set output file self.write_kwargs.setdefault("filename", None) self.write_kwargs["filename"] = self._build_filename( "descriptors.extxyz", filename=self.write_kwargs["filename"] ).absolute() + @staticmethod + def _check_calculator(calc: Calculator) -> None: + """ + Ensure calculator has ability to calculate descriptors. + + Parameters + ---------- + calc : Calculator + ASE Calculator to calculate descriptors. + """ + # If dispersion added to MLIP calculator, use MLIP calculator for descriptors + if isinstance(calc, SumCalculator): + if ( + len(calc.mixer.calcs) == 2 + and calc.mixer.calcs[1].name == "TorchDFTD3Calculator" + and hasattr(calc.mixer.calcs[0], "get_descriptors") + ): + calc.get_descriptors = calc.mixer.calcs[0].get_descriptors + + if not hasattr(calc, "get_descriptors") or not callable(calc.get_descriptors): + raise NotImplementedError( + "The attached calculator does not currently support calculating " + "descriptors" + ) + def run(self) -> None: """Calculate descriptors for structure(s).""" if self.logger: diff --git a/janus_core/calculations/geom_opt.py b/janus_core/calculations/geom_opt.py index 9c07d340..cbf51388 100644 --- a/janus_core/calculations/geom_opt.py +++ b/janus_core/calculations/geom_opt.py @@ -300,7 +300,8 @@ def run(self) -> None: if self.logger: self.logger.info("After optimization spacegroup: %s", s_grp) - self.logger.info("Max force: %.6f", max_force) + self.logger.info("Max force: %s", max_force) + self.logger.info("Final energy: %s", self.struct.get_potential_energy()) if not converged: warnings.warn( diff --git a/janus_core/calculations/md.py b/janus_core/calculations/md.py index 4c294301..026d7adf 100644 --- a/janus_core/calculations/md.py +++ b/janus_core/calculations/md.py @@ -509,12 +509,21 @@ def __init__( self._parse_correlations() - def _set_time_step(self): - """Set time in fs and current dynamics step to info.""" + def _set_info(self): + """Set time in fs, current dynamics step, and density to info.""" time = (self.offset * self.timestep + self.dyn.get_time()) / units.fs step = self.offset + self.dyn.nsteps self.dyn.atoms.info["time_fs"] = time self.dyn.atoms.info["step"] = step + try: + density = ( + np.sum(self.dyn.atoms.get_masses()) + / self.dyn.atoms.get_volume() + * DENS_FACT + ) + self.dyn.atoms.info["density"] = density + except ValueError: + self.dyn.atoms.info["density"] = 0.0 def _prepare_restart(self) -> None: """Prepare restart files, structure and offset.""" @@ -726,19 +735,13 @@ def get_stats(self) -> dict[str, float]: e_kin = self.dyn.atoms.get_kinetic_energy() / self.n_atoms current_temp = e_kin / (1.5 * units.kB) - self._set_time_step() + self._set_info() time_now = datetime.datetime.now() real_time = time_now - self.dyn.atoms.info["real_time"] self.dyn.atoms.info["real_time"] = time_now try: - density = ( - np.sum(self.dyn.atoms.get_masses()) - / self.dyn.atoms.get_volume() - * DENS_FACT - ) - self.dyn.atoms.info["density"] = density volume = self.dyn.atoms.get_volume() pressure = ( -np.trace( @@ -754,7 +757,6 @@ def get_stats(self) -> dict[str, float]: except ValueError: volume = 0.0 pressure = 0.0 - density = 0.0 pressure_tensor = np.zeros(6) return { @@ -765,7 +767,7 @@ def get_stats(self) -> dict[str, float]: "EKin/N": e_kin, "T": current_temp, "ETot/N": e_pot + e_kin, - "Density": density, + "Density": self.dyn.atoms.info["density"], "Volume": volume, "P": pressure, "Pxx": pressure_tensor[0], @@ -874,7 +876,7 @@ def _write_traj(self) -> None: self.dyn.nsteps > self.traj_start + self.traj_start % self.traj_every ) - self._set_time_step() + self._set_info() write_kwargs = self.write_kwargs write_kwargs["filename"] = self.traj_file write_kwargs["append"] = append @@ -895,7 +897,7 @@ def _write_final_state(self) -> None: # Append if final file has been created append = self.created_final_file - self._set_time_step() + self._set_info() write_kwargs = self.write_kwargs write_kwargs["filename"] = self.final_file write_kwargs["append"] = append @@ -998,7 +1000,7 @@ def _write_restart(self) -> None: if step > 0: write_kwargs = self.write_kwargs write_kwargs["filename"] = self._restart_file - self._set_time_step() + self._set_info() output_structs( images=self.struct, diff --git a/janus_core/calculations/phonons.py b/janus_core/calculations/phonons.py index a9299358..24bb4300 100644 --- a/janus_core/calculations/phonons.py +++ b/janus_core/calculations/phonons.py @@ -20,7 +20,7 @@ PathLike, PhononCalcs, ) -from janus_core.helpers.utils import none_to_dict, write_table +from janus_core.helpers.utils import none_to_dict, track_progress, write_table class Phonons(BaseCalculation): @@ -60,6 +60,8 @@ class Phonons(BaseCalculation): Size of supercell for calculation. Default is 2. displacement : float Displacement for force constants calculation, in A. Default is 0.01. + mesh : tuple[int, int, int] + Mesh for sampling. Default is (10, 10, 10). symmetrize : bool Whether to symmetrize force constants after calculation. Default is False. @@ -88,6 +90,8 @@ class Phonons(BaseCalculation): file_prefix : Optional[PathLike] Prefix for output filenames. Default is inferred from chemical formula of the structure. + enable_progress_bar : bool + Whether to show a progress bar during phonon calculations. Default is False. Attributes ---------- @@ -106,7 +110,7 @@ class Phonons(BaseCalculation): Calculate band structure and optionally write and plot results. write_bands(bands_file, save_plots, plot_file) Write results of band structure calculations. - calc_thermal_props(write_thermal) + calc_thermal_props(mesh, write_thermal) Calculate thermal properties and optionally write results. write_thermal_props(thermal_file) Write results of thermal properties calculations. @@ -138,6 +142,7 @@ def __init__( calcs: MaybeSequence[PhononCalcs] = (), supercell: MaybeList[int] = 2, displacement: float = 0.01, + mesh: tuple[int, int, int] = (10, 10, 10), symmetrize: bool = False, minimize: bool = False, minimize_kwargs: Optional[dict[str, Any]] = None, @@ -149,6 +154,7 @@ def __init__( write_results: bool = True, write_full: bool = True, file_prefix: Optional[PathLike] = None, + enable_progress_bar: bool = False, ) -> None: """ Initialise Phonons class. @@ -186,6 +192,8 @@ def __init__( Size of supercell for calculation. Default is 2. displacement : float Displacement for force constants calculation, in A. Default is 0.01. + mesh : tuple[int, int, int] + Mesh for sampling. Default is (10, 10, 10). symmetrize : bool Whether to symmetrize force constants after calculations. Default is False. @@ -214,11 +222,14 @@ def __init__( file_prefix : Optional[PathLike] Prefix for output filenames. Default is inferred from structure name, or chemical formula of the structure. + enable_progress_bar : bool + Whether to show a progress bar during phonon calculations. Default is False. """ (read_kwargs, minimize_kwargs) = none_to_dict((read_kwargs, minimize_kwargs)) self.calcs = calcs self.displacement = displacement + self.mesh = mesh self.symmetrize = symmetrize self.minimize = minimize self.minimize_kwargs = minimize_kwargs @@ -229,6 +240,7 @@ def __init__( self.plot_to_file = plot_to_file self.write_results = write_results self.write_full = write_full + self.enable_progress_bar = enable_progress_bar # Ensure supercell is a valid list self.supercell = [supercell] * 3 if isinstance(supercell, int) else supercell @@ -357,6 +369,11 @@ def calc_force_constants( phonon.generate_displacements(distance=self.displacement) disp_supercells = phonon.supercells_with_displacements + if self.enable_progress_bar: + disp_supercells = track_progress( + disp_supercells, "Computing displacements..." + ) + phonon.forces = [ self._calc_forces(supercell) for supercell in disp_supercells @@ -490,13 +507,18 @@ def write_bands( bplt.savefig(plot_file) def calc_thermal_props( - self, write_thermal: Optional[bool] = None, **kwargs + self, + mesh: Optional[tuple[int, int, int]] = None, + write_thermal: Optional[bool] = None, + **kwargs, ) -> None: """ Calculate thermal properties and optionally write results. Parameters ---------- + mesh : Optional[tuple[int, int, int]] + Mesh for sampling. Default is self.mesh. write_thermal : Optional[bool] Whether to write out thermal properties to file. Default is self.write_results. @@ -506,6 +528,9 @@ def calc_thermal_props( if write_thermal is None: write_thermal = self.write_results + if mesh is None: + mesh = self.mesh + # Calculate phonons if not already in results if "phonon" not in self.results: # Use general (self.write_results) setting for writing force constants @@ -515,7 +540,7 @@ def calc_thermal_props( self.logger.info("Starting thermal properties calculation") self.tracker.start_task("Thermal calculation") - self.results["phonon"].run_mesh() + self.results["phonon"].run_mesh(mesh) self.results["phonon"].run_thermal_properties( t_step=self.temp_step, t_max=self.temp_max, t_min=self.temp_min ) @@ -563,7 +588,7 @@ def write_thermal_props(self, thermal_file: Optional[PathLike] = None) -> None: def calc_dos( self, *, - mesh: MaybeList[float] = (10, 10, 10), + mesh: Optional[tuple[int, int, int]] = None, write_dos: Optional[bool] = None, **kwargs, ) -> None: @@ -572,8 +597,8 @@ def calc_dos( Parameters ---------- - mesh : MaybeList[float] - Mesh for sampling. Default is (10, 10, 10). + mesh : Optional[tuple[int, int, int]] + Mesh for sampling. Default is self.mesh. write_dos : Optional[bool] Whether to write out results to file. Default is True. **kwargs @@ -582,6 +607,9 @@ def calc_dos( if write_dos is None: write_dos = self.write_results + if mesh is None: + mesh = self.mesh + # Calculate phonons if not already in results if "phonon" not in self.results: # Use general (self.write_results) setting for writing force constants @@ -665,7 +693,7 @@ def write_dos( def calc_pdos( self, *, - mesh: MaybeList[float] = (10, 10, 10), + mesh: Optional[tuple[int, int, int]] = None, write_pdos: Optional[bool] = None, **kwargs, ) -> None: @@ -674,8 +702,8 @@ def calc_pdos( Parameters ---------- - mesh : MaybeList[float] - Mesh for sampling. Default is (10, 10, 10). + mesh : Optional[tuple[int, int, int]] + Mesh for sampling. Default is self.mesh. write_pdos : Optional[bool] Whether to write out results to file. Default is self.write_results. **kwargs @@ -684,6 +712,9 @@ def calc_pdos( if write_pdos is None: write_pdos = self.write_results + if mesh is None: + mesh = self.mesh + # Calculate phonons if not already in results if "phonon" not in self.results: # Use general (self.write_results) setting for writing force constants diff --git a/janus_core/cli/phonons.py b/janus_core/cli/phonons.py index 92cfa2a7..64fdb0ed 100644 --- a/janus_core/cli/phonons.py +++ b/janus_core/cli/phonons.py @@ -21,6 +21,7 @@ from janus_core.cli.utils import ( carbon_summary, check_config, + dict_tuples_to_lists, end_summary, parse_typer_dicts, save_struct_calc, @@ -39,30 +40,22 @@ def phonons( ctx: Context, struct: StructPath, supercell: Annotated[ - str, - Option(help="Supercell lattice vectors in the form '1x2x3'."), - ] = "2x2x2", + tuple[int, int, int], Option(help="Supercell lattice vectors.") + ] = (2, 2, 2), displacement: Annotated[ - float, - Option(help="Displacement for force constants calculation, in A."), + float, Option(help="Displacement for force constants calculation, in A.") ] = 0.01, + mesh: Annotated[ + tuple[int, int, int], Option(help="Mesh numbers along a, b, c axes.") + ] = (10, 10, 10), bands: Annotated[ bool, Option(help="Whether to compute band structure."), ] = False, - dos: Annotated[ - bool, - Option(help="Whether to calculate the DOS."), - ] = False, - pdos: Annotated[ - bool, - Option( - help="Whether to calculate the PDOS.", - ), - ] = False, + dos: Annotated[bool, Option(help="Whether to calculate the DOS.")] = False, + pdos: Annotated[bool, Option(help="Whether to calculate the PDOS.")] = False, thermal: Annotated[ - bool, - Option(help="Whether to calculate thermal properties."), + bool, Option(help="Whether to calculate thermal properties.") ] = False, temp_min: Annotated[ float, @@ -80,18 +73,14 @@ def phonons( bool, Option(help="Whether to symmetrize force constants.") ] = False, minimize: Annotated[ - bool, - Option( - help="Whether to minimize structure before calculations.", - ), + bool, Option(help="Whether to minimize structure before calculations.") ] = False, fmax: Annotated[ float, Option(help="Maximum force for optimization convergence.") ] = 0.1, minimize_kwargs: MinimizeKwargs = None, hdf5: Annotated[ - bool, - Option(help="Whether to save force constants in hdf5."), + bool, Option(help="Whether to save force constants in hdf5.") ] = True, plot_to_file: Annotated[ bool, @@ -133,11 +122,12 @@ def phonons( Typer (Click) Context. Automatically set. struct : Path Path of structure to simulate. - supercell : str - Supercell lattice vectors. Must be passed in the form '1x2x3'. Default is - 2x2x2. + supercell : tuple[int, int, int] + Supercell lattice vectors. Default is (2, 2, 2). displacement : float Displacement for force constants calculation, in A. Default is 0.01. + mesh : tuple[int, int, int] + Mesh for sampling. Default is (10, 10, 10). bands : bool Whether to calculate and save the band structure. Default is False. dos : bool @@ -209,17 +199,6 @@ def phonons( raise ValueError("'fmax' must be passed through the --fmax option") minimize_kwargs["fmax"] = fmax - try: - supercell = [int(x) for x in supercell.split("x")] - except ValueError as exc: - raise ValueError( - "Please pass lattice vectors as integers in the form 1x2x3" - ) from exc - - # Validate supercell list - if len(supercell) != 3: - raise ValueError("Please pass three lattice vectors in the form 1x2x3") - calcs = [] if bands: calcs.append("bands") @@ -247,6 +226,7 @@ def phonons( "calcs": calcs, "supercell": supercell, "displacement": displacement, + "mesh": mesh, "symmetrize": symmetrize, "minimize": minimize, "minimize_kwargs": minimize_kwargs, @@ -258,6 +238,7 @@ def phonons( "write_results": True, "write_full": write_full, "file_prefix": file_prefix, + "enable_progress_bar": True, } # Initialise phonons @@ -283,6 +264,9 @@ def phonons( log=log, ) + # Convert all tuples to list in inputs nested dictionary + dict_tuples_to_lists(inputs) + # Save summary information before calculations begin start_summary(command="phonons", summary=summary, inputs=inputs) diff --git a/janus_core/cli/singlepoint.py b/janus_core/cli/singlepoint.py index 3fa65877..2542ce0f 100644 --- a/janus_core/cli/singlepoint.py +++ b/janus_core/cli/singlepoint.py @@ -140,8 +140,8 @@ def singlepoint( ).absolute() log = s_point.log_kwargs["filename"] - # Store only filename as filemode is not set by user - inputs = {"log": log} + # Store inputs for yaml summary + inputs = singlepoint_kwargs.copy() # Add structure, MLIP information, and log to inputs save_struct_calc( diff --git a/janus_core/cli/utils.py b/janus_core/cli/utils.py index e87c24ff..e2619a71 100644 --- a/janus_core/cli/utils.py +++ b/janus_core/cli/utils.py @@ -36,6 +36,22 @@ def dict_paths_to_strs(dictionary: dict) -> None: dictionary[key] = str(value) +def dict_tuples_to_lists(dictionary: dict) -> None: + """ + Recursively iterate over dictionary, converting tuple values to lists. + + Parameters + ---------- + dictionary : dict + Dictionary to be converted. + """ + for key, value in dictionary.items(): + if isinstance(value, dict): + dict_paths_to_strs(value) + elif isinstance(value, tuple): + dictionary[key] = list(value) + + def dict_remove_hyphens(dictionary: dict) -> dict: """ Recursively iterate over dictionary, replacing hyphens with underscores in keys. diff --git a/janus_core/helpers/log.py b/janus_core/helpers/log.py index 507546ba..8f6aa390 100644 --- a/janus_core/helpers/log.py +++ b/janus_core/helpers/log.py @@ -186,6 +186,7 @@ def config_tracker( logging_logger=carbon_logger, project_name="janus-core", log_level=log_level, + allow_multiple_runs=True, ) # Suppress further logging from codecarbon diff --git a/janus_core/helpers/utils.py b/janus_core/helpers/utils.py index 6c61dd13..7e439ce5 100644 --- a/janus_core/helpers/utils.py +++ b/janus_core/helpers/utils.py @@ -6,11 +6,19 @@ from io import StringIO import logging from pathlib import Path -from typing import Any, Literal, Optional, TextIO, get_args +from typing import Any, Literal, Optional, TextIO, Union, get_args from ase import Atoms from ase.io import read, write from ase.io.formats import filetype +from rich.progress import ( + BarColumn, + MofNCompleteColumn, + Progress, + TextColumn, + TimeRemainingColumn, +) +from rich.style import Style from spglib import get_spacegroup from janus_core.helpers.janus_types import ( @@ -674,3 +682,44 @@ def _dump_csv( for cols in zip(*columns.values()): print(",".join(map(format, cols, formats)), file=file) + + +def track_progress(sequence: Union[Sequence, Iterable], description: str) -> Iterable: + """ + Track the progress of iterating over a sequence. + + This is done by displaying a progress bar in the console using the rich library. + The function is an iterator over the sequence, updating the progress bar each + iteration. + + Parameters + ---------- + sequence : Iterable + The sequence to iterate over. Must support "len". + description : str + The text to display to the left of the progress bar. + + Yields + ------ + Iterable + An iterable of the values in the sequence. + """ + text_column = TextColumn("{task.description}") + bar_column = BarColumn( + bar_width=None, + complete_style=Style(color="#FBBB10"), + finished_style=Style(color="#E38408"), + ) + completion_column = MofNCompleteColumn() + time_column = TimeRemainingColumn() + progress = Progress( + text_column, + bar_column, + completion_column, + time_column, + expand=True, + auto_refresh=False, + ) + + with progress: + yield from progress.track(sequence, description=description) diff --git a/pyproject.toml b/pyproject.toml index 9e23931f..ed65a0e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "janus-core" -version = "0.6.4" +version = "0.6.4.1" description = "Tools for machine learnt interatomic potentials" authors = [ "Elliott Kasoar", @@ -33,6 +33,7 @@ numpy = "^1.26.4" phonopy = "^2.23.1" python = "^3.9" pyyaml = "^6.0.1" +rich = "^13.9.1" seekpath = "^1.9.7" spglib = "^2.3.0" torch = ">= 2.1, <= 2.2" # Range required for dgl diff --git a/tests/test_descriptors.py b/tests/test_descriptors.py index 8818e84c..c1e1ddc6 100644 --- a/tests/test_descriptors.py +++ b/tests/test_descriptors.py @@ -88,3 +88,48 @@ def test_logging(tmp_path): assert log_file.exists() assert single_point.struct.info["emissions"] > 0 + + +def test_dispersion(): + """Test using mace_mp with dispersion.""" + single_point = SinglePoint( + struct_path=DATA_PATH / "NaCl.cif", + arch="mace_mp", + calc_kwargs={"dispersion": False}, + ) + + descriptors = Descriptors( + single_point.struct, + calc_per_element=True, + ) + descriptors.run() + + single_point_disp = SinglePoint( + struct_path=DATA_PATH / "NaCl.cif", + arch="mace_mp", + calc_kwargs={"dispersion": True}, + ) + + descriptors_disp = Descriptors( + single_point_disp.struct, + calc_per_element=True, + ) + descriptors_disp.run() + + assert ( + descriptors_disp.struct.info["mace_mp_descriptor"] + == descriptors.struct.info["mace_mp_descriptor"] + ) + + +def test_not_implemented_error(): + """Test correct error raised if descriptors not implemented.""" + single_point = SinglePoint( + struct_path=DATA_PATH / "NaCl.cif", + arch="chgnet", + ) + with pytest.raises(NotImplementedError): + Descriptors( + single_point.struct, + calc_per_element=True, + ) diff --git a/tests/test_geomopt_cli.py b/tests/test_geomopt_cli.py index de28f3b0..794385ce 100644 --- a/tests/test_geomopt_cli.py +++ b/tests/test_geomopt_cli.py @@ -85,8 +85,15 @@ def test_log(tmp_path): ) assert result.exit_code == 0 + # Only check reduced precision of energy and max force assert_log_contains( - log_path, includes="Starting geometry optimization", excludes="Using filter" + log_path, + includes=[ + "Starting geometry optimization", + "Final energy: -27.035127", + "Max force: ", + ], + excludes="Using filter", ) diff --git a/tests/test_md.py b/tests/test_md.py index ae068b67..ef815767 100644 --- a/tests/test_md.py +++ b/tests/test_md.py @@ -1058,3 +1058,30 @@ def test_auto_restart_restart_stem(tmp_path): final_traj = read(traj_path, index=":") assert len(final_traj) == 9 + + +def test_set_info(tmp_path): + """Test info is set at correct frequency.""" + file_prefix = tmp_path / "npt" + traj_path = tmp_path / "npt-traj.extxyz" + + single_point = SinglePoint( + struct_path=DATA_PATH / "NaCl.cif", + arch="mace", + calc_kwargs={"model": MODEL_PATH}, + ) + + npt = NPT( + struct=single_point.struct, + steps=10, + temp=1000, + stats_every=7, + file_prefix=file_prefix, + seed=2024, + traj_every=10, + ) + + npt.run() + final_struct = read(traj_path, index="-1") + assert npt.struct.info["density"] == pytest.approx(2.120952627887493) + assert final_struct.info["density"] == pytest.approx(2.120952627887493) diff --git a/tests/test_phonons_cli.py b/tests/test_phonons_cli.py index c4c9fb07..9175469d 100644 --- a/tests/test_phonons_cli.py +++ b/tests/test_phonons_cli.py @@ -228,7 +228,9 @@ def test_plot(tmp_path): "--struct", DATA_PATH / "NaCl.cif", "--supercell", - "1x1x1", + 1, + 1, + 1, "--pdos", "--dos", "--bands", @@ -268,7 +270,9 @@ def test_supercell(tmp_path): "--struct", DATA_PATH / "NaCl.cif", "--supercell", - "1x2x3", + 1, + 2, + 3, "--no-hdf5", "--file-prefix", file_prefix, @@ -285,10 +289,7 @@ def test_supercell(tmp_path): assert params["supercell_matrix"] == [[1, 0, 0], [0, 2, 0], [0, 0, 3]] -test_data = ["2", "2.1x2.1x2.1", "2x2xa"] - - -@pytest.mark.parametrize("supercell", test_data) +@pytest.mark.parametrize("supercell", [(2,), (2, 2), (2, 2, "a"), ("2x2x2",)]) def test_invalid_supercell(supercell, tmp_path): """Test errors are raise for invalid supercells.""" file_prefix = tmp_path / "test" @@ -300,13 +301,12 @@ def test_invalid_supercell(supercell, tmp_path): "--struct", DATA_PATH / "NaCl.cif", "--supercell", - supercell, + *supercell, "--file-prefix", file_prefix, ], ) - assert result.exit_code == 1 - assert isinstance(result.exception, ValueError) + assert result.exit_code == 1 or result.exit_code == 2 def test_minimize_kwargs(tmp_path): @@ -379,7 +379,9 @@ def test_valid_traj_input(read_kwargs, tmp_path): "--struct", DATA_PATH / "NaCl-traj.xyz", "--supercell", - "1x1x1", + 1, + 1, + 1, "--read-kwargs", read_kwargs, "--no-hdf5", diff --git a/tests/test_singlepoint_cli.py b/tests/test_singlepoint_cli.py index 044f18b3..6330ce3b 100644 --- a/tests/test_singlepoint_cli.py +++ b/tests/test_singlepoint_cli.py @@ -240,6 +240,7 @@ def test_summary(tmp_path): assert "inputs" in sp_summary assert "end_time" in sp_summary + assert "properties" in sp_summary["inputs"] assert "traj" in sp_summary["inputs"] assert "length" in sp_summary["inputs"]["traj"] assert "struct" in sp_summary["inputs"]["traj"]