diff --git a/src/nomad_simulations/numerical_settings.py b/src/nomad_simulations/numerical_settings.py index 7f0bc0c7..de0912e1 100644 --- a/src/nomad_simulations/numerical_settings.py +++ b/src/nomad_simulations/numerical_settings.py @@ -160,14 +160,17 @@ def normalize(self, archive, logger) -> None: super().normalize(archive, logger) -class KMeshBase(Mesh): +class KSpaceFunctionalities: """ - A base section used for abstraction for `KMesh` and `KLinePath` sections. It contains the methods - `_check_reciprocal_lattice_vectors` and `resolve_high_symmetry_points` that are used in both sections. + A functionality class useful for defining methods shared between `KSpace`, `KMesh`, and `KLinePath`. """ def _check_reciprocal_lattice_vectors( - self, reciprocal_lattice_vectors: Optional[pint.Quantity], logger: BoundLogger + self, + reciprocal_lattice_vectors: Optional[pint.Quantity], + logger: BoundLogger, + check_grid: bool = False, + grid: Optional[List[int]] = [], ) -> bool: """ Check if the `reciprocal_lattice_vectors` exist and if they have the same dimensionality as `grid`. @@ -175,16 +178,25 @@ def _check_reciprocal_lattice_vectors( Args: reciprocal_lattice_vectors (Optional[pint.Quantity]): The reciprocal lattice vectors of the atomic cell. logger (BoundLogger): The logger to log messages. + check_grid (bool, optional): Flag to check the `grid` is set to True. Defaults to False. + grid (Optional[List[int]], optional): The grid of the `KMesh`. Defaults to []. Returns: - (bool): True if the `reciprocal_lattice_vectors` exist and have the same dimensionality as `grid`, False otherwise. + (bool): True if the `reciprocal_lattice_vectors` exist. If `check_grid_too` is set to True, it also checks if the + `reciprocal_lattice_vectors` and the `grid` have the same dimensionality. False otherwise. """ - if reciprocal_lattice_vectors is None or self.grid is None: - logger.warning( - 'Could not find `reciprocal_lattice_vectors` from parent `KSpace` or could not find `KMesh.grid`.' - ) + if reciprocal_lattice_vectors is None: + logger.warning('Could not find `reciprocal_lattice_vectors`.') return False - if len(reciprocal_lattice_vectors) != 3 or len(self.grid) != 3: + # Only checking the `reciprocal_lattice_vectors` + if not check_grid: + return True + + # Checking the `grid` too + if grid is None: + logger.warning('Could not find `grid`.') + return False + if len(reciprocal_lattice_vectors) != 3 or len(grid) != 3: logger.warning( 'The `reciprocal_lattice_vectors` and the `grid` should have the same dimensionality.' ) @@ -284,9 +296,6 @@ def resolve_high_symmetry_points( high_symmetry_points[key] = list(value) return high_symmetry_points - def normalize(self, archive, logger) -> None: - super().normalize(archive, logger) - class KMesh(Mesh): """ @@ -395,8 +404,11 @@ def get_k_line_density( (np.float64): The k-line density of the `KMesh`. """ # Initial check - if not self._check_reciprocal_lattice_vectors( - reciprocal_lattice_vectors, logger + if not KSpaceFunctionalities._check_reciprocal_lattice_vectors( + reciprocal_lattice_vectors=reciprocal_lattice_vectors, + logger=logger, + check_grid=True, + grid=self.grid, ): return None @@ -426,8 +438,11 @@ def resolve_k_line_density( (Optional[pint.Quantity]): The resolved `k_line_density` of the `KMesh`. """ # Initial check - if not self._check_reciprocal_lattice_vectors( - reciprocal_lattice_vectors, logger + if not KSpaceFunctionalities._check_reciprocal_lattice_vectors( + reciprocal_lattice_vectors=reciprocal_lattice_vectors, + logger=logger, + check_grid=True, + grid=self.grid, ): return None @@ -473,8 +488,10 @@ def normalize(self, archive, logger) -> None: # Resolve `high_symmetry_points` if self.high_symmetry_points is None: - self.high_symmetry_points = self.resolve_high_symmetry_points( - model_systems, logger + self.high_symmetry_points = ( + KSpaceFunctionalities.resolve_high_symmetry_points( + model_systems=model_systems, logger=logger + ) ) @@ -485,21 +502,6 @@ class KLinePath(ArchiveSection): value, one should multiply them by the reciprocal lattice vectors (`points_cartesian = points @ reciprocal_lattice_vectors`). """ - # high_symmetry_path = Quantity( - # type=JSON, - # shape=['*'], - # description=""" - # List of dictionaries containing the high-symmetry path (in units of the `reciprocal_lattice_vectors`) followed in - # the k-line path. E.g., in a cubic lattice: - # high_symmetry_path = [ - # {'Gamma': [0, 0, 0]}, - # {'X': [0.5, 0, 0]}, - # {'Y': [0, 0.5, 0]}, - # {'Gamma': [0, 0, 0]}, - # ] - # """, - # ) - high_symmetry_path_names = Quantity( type=str, shape=['*'], @@ -533,34 +535,6 @@ class KLinePath(ArchiveSection): """, ) - def _check_high_symmetry_path(self, logger: BoundLogger) -> bool: - """ - Check if the `high_symmetry_path_names` and `high_symmetry_path_values` are defined and have the same length. - - Args: - logger (BoundLogger): The logger to log messages. - - Returns: - (bool): True if the `high_symmetry_path_names` and `high_symmetry_path_values` are defined and have the same length, False otherwise. - """ - if ( - self.high_symmetry_path_names is None - or self.high_symmetry_path_values is None - ) or ( - len(self.high_symmetry_path_names) == 0 - or len(self.high_symmetry_path_values) == 0 - ): - logger.warning( - 'Could not find `KLinePath.high_symmetry_path_names` or `KLinePath.high_symmetry_path_values`.' - ) - return False - if len(self.high_symmetry_path_names) != len(self.high_symmetry_path_values): - logger.warning( - 'The length of `KLinePath.high_symmetry_path_names` and `KLinePath.high_symmetry_path_values` should coincide.' - ) - return False - return True - def resolve_high_symmetry_path_values( self, model_systems: List[ModelSystem], @@ -579,13 +553,15 @@ def resolve_high_symmetry_path_values( (Optional[List[float]]): The resolved `high_symmetry_path_values`. """ # Initial check on the `reciprocal_lattice_vectors` - if not self._check_reciprocal_lattice_vectors( - reciprocal_lattice_vectors, logger + if not KSpaceFunctionalities._check_reciprocal_lattice_vectors( + reciprocal_lattice_vectors=reciprocal_lattice_vectors, logger=logger ): return [] # Resolving the dictionary containing the `high_symmetry_points` for the given ModelSystem symmetry - high_symmetry_points = self.resolve_high_symmetry_points(model_systems, logger) + high_symmetry_points = KSpaceFunctionalities.resolve_high_symmetry_points( + model_systems=model_systems, logger=logger + ) if high_symmetry_points is None: return [] @@ -598,6 +574,34 @@ def resolve_high_symmetry_path_values( ] return high_symmetry_path_values + def _check_high_symmetry_path(self, logger: BoundLogger) -> bool: + """ + Check if the `high_symmetry_path_names` and `high_symmetry_path_values` are defined and have the same length. + + Args: + logger (BoundLogger): The logger to log messages. + + Returns: + (bool): True if the `high_symmetry_path_names` and `high_symmetry_path_values` are defined and have the same length, False otherwise. + """ + if ( + self.high_symmetry_path_names is None + or self.high_symmetry_path_values is None + ) or ( + len(self.high_symmetry_path_names) == 0 + or len(self.high_symmetry_path_values) == 0 + ): + logger.warning( + 'Could not find `KLinePath.high_symmetry_path_names` or `KLinePath.high_symmetry_path_values`.' + ) + return False + if len(self.high_symmetry_path_names) != len(self.high_symmetry_path_values): + logger.warning( + 'The length of `KLinePath.high_symmetry_path_names` and `KLinePath.high_symmetry_path_values` should coincide.' + ) + return False + return True + def get_high_symmetry_path_norms( self, reciprocal_lattice_vectors: Optional[pint.Quantity], @@ -759,6 +763,7 @@ class KSpace(NumericalSettings): depending on the k-space sampling: `k_mesh` or `k_line_path`. """ + # ! This needs to be normalized first in order to extract the `reciprocal_lattice_vectors` from the `ModelSystem.cell` information reciprocal_lattice_vectors = Quantity( type=np.float64, shape=[3, 3], diff --git a/tests/test_numerical_settings.py b/tests/test_numerical_settings.py index fcb4c8ff..4cc7150b 100644 --- a/tests/test_numerical_settings.py +++ b/tests/test_numerical_settings.py @@ -18,12 +18,12 @@ import pytest import numpy as np -from typing import Optional, List, Dict +from typing import Optional, List from nomad.units import ureg from nomad.datamodel import EntryArchive -from nomad_simulations.numerical_settings import KMesh, KLinePath +from nomad_simulations.numerical_settings import KMesh, KLinePath, KSpaceFunctionalities from . import logger from .conftest import generate_k_line_path, generate_k_space_simulation @@ -75,7 +75,64 @@ def test_normalize( assert k_space.reciprocal_lattice_vectors == result -# TODO add testing for KMesh +class TestKSpaceFunctionalities: + """ + Test the `KSpaceFunctionalities` class defined in `numerical_settings.py`. + """ + + @pytest.mark.parametrize( + 'reciprocal_lattice_vectors, check_grid, grid, result', + [ + (None, None, None, False), + ([[1, 0, 0], [0, 1, 0], [0, 0, 1]], False, None, True), + ([[1, 0, 0], [0, 1, 0], [0, 0, 1]], True, None, False), + ([[1, 0, 0], [0, 1, 0], [0, 0, 1]], True, [6, 6, 6, 4], False), + ([[1, 0, 0], [0, 1, 0], [0, 0, 1]], True, [6, 6, 6], True), + ], + ) + def test_check_reciprocal_lattice_vectors( + self, + reciprocal_lattice_vectors: Optional[List[List[float]]], + check_grid: bool, + grid: Optional[List[int]], + result: bool, + ): + """ + Test the `_check_reciprocal_lattice_vectors` private method. + """ + check = KSpaceFunctionalities._check_reciprocal_lattice_vectors( + reciprocal_lattice_vectors=reciprocal_lattice_vectors, + logger=logger, + check_grd=check_grid, + grid=grid, + ) + assert check == result + + def test_resolve_high_symmetry_points(self): + """ + Test the `resolve_high_symmetry_points` method. Only testing the valid situation in which the `ModelSystem` normalization worked. + """ + # `ModelSystem.normalize()` need to extract `bulk` as a type. + simulation = generate_k_space_simulation( + pbc=[True, True, True], + ) + model_systems = simulation.model_system + # normalize to extract symmetry + simulation.model_system[0].normalize(EntryArchive(), logger) + + # Testing the functionality method + high_symmetry_points = KSpaceFunctionalities.resolve_high_symmetry_points( + model_systems=model_systems, logger=logger + ) + assert len(high_symmetry_points) == 4 + assert high_symmetry_points == { + 'Gamma': [0, 0, 0], + 'M': [0.5, 0.5, 0], + 'R': [0.5, 0.5, 0.5], + 'X': [0, 0.5, 0], + } + + class TestKMesh: """ Test the `KMesh` class defined in `numerical_settings.py`. @@ -231,28 +288,6 @@ def test_resolve_k_line_density( else: assert k_line_density == result_k_line_density - def test_resolve_high_symmetry_points(self): - """ - Test the `resolve_high_symmetry_points` method. Only testing the valid situation in which the `ModelSystem` normalization worked. - """ - # `ModelSystem.normalize()` need to extract `bulk` as a type. - simulation = generate_k_space_simulation( - pbc=[True, True, True], - ) - model_system = simulation.model_system[0] - model_system.normalize(EntryArchive(), logger) # normalize to extract symmetry - k_mesh = simulation.model_method[0].numerical_settings[0].k_mesh[0] - high_symmetry_points = k_mesh.resolve_high_symmetry_points( - simulation.model_system, logger - ) - assert len(high_symmetry_points) == 4 - assert high_symmetry_points == { - 'Gamma': [0, 0, 0], - 'M': [0.5, 0.5, 0], - 'R': [0.5, 0.5, 0.5], - 'X': [0, 0.5, 0], - } - class TestKLinePath: """ @@ -284,6 +319,50 @@ def test_check_high_symmetry_path( ) assert k_line_path._check_high_symmetry_path(logger) == result + @pytest.mark.parametrize( + 'high_symmetry_path_names, result', + [ + (['Gamma', 'X', 'R'], [[0, 0, 0], [0, 0.5, 0], [0.5, 0.5, 0.5]]), + ], + ) + def test_resolve_high_symmetry_path_values( + self, + high_symmetry_path_names: List[str], + result: bool, + ): + """ + Test the `resolve_high_symmetry_path_values` method. Only testing the valid situation in which the `ModelSystem` normalization worked. + """ + # `ModelSystem.normalize()` need to extract `bulk` as a type. + simulation = generate_k_space_simulation( + pbc=[True, True, True], + high_symmetry_path_names=high_symmetry_path_names, + high_symmetry_path_values=None, + ) + model_system = simulation.model_system[0] + model_system.normalize(EntryArchive(), logger) # normalize to extract symmetry + # getting `reciprocal_lattice_vectors` + reciprocal_lattice_vectors = ( + simulation.model_method[0].numerical_settings[0].reciprocal_lattice_vectors + ) + + # `KLinePath` can be understood as a `KMeshBase` section + k_line_path = simulation.model_method[0].numerical_settings[0].k_line_path + k_line_path.normalize(EntryArchive(), logger) + high_symmetry_points_values = k_line_path.resolve_high_symmetry_path_values( + simulation.model_system, reciprocal_lattice_vectors, logger + ) + # high_symmetry_points = k_mesh_base.resolve_high_symmetry_points( + # simulation.model_system, logger + # ) + # assert len(high_symmetry_points) == 4 + # assert high_symmetry_points == { + # 'Gamma': [0, 0, 0], + # 'M': [0.5, 0.5, 0], + # 'R': [0.5, 0.5, 0.5], + # 'X': [0, 0.5, 0], + # } + def test_get_high_symmetry_path_norm(self, k_line_path: KLinePath): """ Test the `get_high_symmetry_path_norm` method.