From 6f7a7ccc254cbbe5d38110f455e0a588d1a35919 Mon Sep 17 00:00:00 2001 From: JosePizarro3 Date: Wed, 22 May 2024 13:12:07 +0200 Subject: [PATCH] Added functional programming for KLinePath.resolve_points Added testing for check_high_symmetry_path --- src/nomad_simulations/numerical_settings.py | 102 +++++++++----------- tests/test_numerical_settings.py | 27 +++++- 2 files changed, 71 insertions(+), 58 deletions(-) diff --git a/src/nomad_simulations/numerical_settings.py b/src/nomad_simulations/numerical_settings.py index b14e1e24..224f7694 100644 --- a/src/nomad_simulations/numerical_settings.py +++ b/src/nomad_simulations/numerical_settings.py @@ -18,7 +18,7 @@ import numpy as np import pint -import itertools +from itertools import accumulate, tee, chain from structlog.stdlib import BoundLogger from typing import Optional, List, Tuple, Union, Dict from ase.dft.kpoints import monkhorst_pack, get_monkhorst_pack_size_and_offset @@ -536,6 +536,9 @@ def _check_high_symmetry_path(self, logger: BoundLogger) -> bool: if ( self.high_symmetry_path_names is None or self.high_symmetry_path_values is None + ) or ( + len(self.high_symmetry_path_names) == 0 + or len(self.high_symmetry_path_values) == 0 ): logger.warning( 'Could not find `KLinePath.high_symmetry_path_names` or `KLinePath.high_symmetry_path_values`.' @@ -575,12 +578,12 @@ def get_high_symmetry_path_norms( return None rlv = reciprocal_lattice_vectors.magnitude - def calc_norms(value_rlv, prev_value_rlv): + def calc_norms( + value_rlv: np.ndarray, prev_value_rlv: np.ndarray + ) -> pint.Quantity: value_tot_rlv = value_rlv - prev_value_rlv return np.linalg.norm(value_tot_rlv) * reciprocal_lattice_vectors.u - from itertools import accumulate, tee - # Compute `rlv` projections rlv_projections = list( map(lambda value: value @ rlv, self.high_symmetry_path_values) @@ -589,8 +592,7 @@ def calc_norms(value_rlv, prev_value_rlv): # Create two iterators for the projections rlv_projections_1, rlv_projections_2 = tee(rlv_projections) - # Initialize the previous value iterators and skip the first element in the second iterator - prev_value_rlv = np.array([0, 0, 0]) + # Skip the first element in the second iterator next(rlv_projections_2, None) # Calculate the norms using accumulate @@ -601,29 +603,6 @@ def calc_norms(value_rlv, prev_value_rlv): ) return list(norms) - # # initializing the norms list (the first point has a norm of 0) - # high_symmetry_path_value_norms = [0.0 * reciprocal_lattice_vectors.u] - # # initializing the first point - # prev_value_norm = 0.0 * reciprocal_lattice_vectors.u - # prev_value_rlv = np.array([0, 0, 0]) - # for i, value in enumerate(self.high_symmetry_path_values): - # if i == 0: - # continue - # value_rlv = value @ rlv - # value_tot_rlv = value_rlv - prev_value_rlv - # value_norm = ( - # np.linalg.norm(value_tot_rlv) * reciprocal_lattice_vectors.u - # + prev_value_norm - # ) - - # # store in new path norms variable - # high_symmetry_path_value_norms.append(value_norm) - - # # accumulate value vector and norm - # prev_value_rlv = value_rlv - # prev_value_norm = value_norm - # return high_symmetry_path_value_norms - def resolve_points( self, points_norm: Union[np.ndarray, List[float]], @@ -647,6 +626,7 @@ def resolve_points( 'The `reciprocal_lattice_vectors` are not passed as an input.' ) return None + # Check if `points_norm` is a list and convert it to a numpy array if isinstance(points_norm, list): points_norm = np.array(points_norm) @@ -658,38 +638,46 @@ def resolve_points( ) self.n_line_points = len(points_norm) - # Calculate the total norm of the path in order to find the closest indices in the list of `points_norm` + # Calculate the norms in the path and find the closest indices in points_norm to the high symmetry path norms high_symmetry_path_value_norms = self.get_high_symmetry_path_norms( reciprocal_lattice_vectors, logger ) - closest_indices = [] - for i, norm in enumerate(high_symmetry_path_value_norms): - closest_idx = (np.abs(points_norm - norm.magnitude)).argmin() - closest_indices.append(closest_idx) - - # Append the data in the new `points` in units of the `reciprocal_lattice_vectors` - points = [] - for i, value in enumerate(self.high_symmetry_path_values): - if i == 0: - prev_value = value - prev_index = closest_indices[i] - continue - elif i == len(self.high_symmetry_path_values) - 1: - points.append( - np.linspace( - prev_value, value, num=closest_indices[i] - prev_index + 1 - ) - ) - else: - # pop the last element as it appears repeated in the next segment - points.append( - np.linspace( - prev_value, value, num=closest_indices[i] - prev_index + 1 - )[:-1] + closest_indices = list( + map( + lambda norm: (np.abs(points_norm - norm.magnitude)).argmin(), + high_symmetry_path_value_norms, + ) + ) + + def linspace_segments( + prev_value: np.ndarray, value: np.ndarray, num: int + ) -> np.ndarray: + return np.linspace(prev_value, value, num=num + 1)[:-1] + + # Generate point segments using `map` and `linspace_segments` + points_segments = list( + map( + lambda i, value: linspace_segments( + self.high_symmetry_path_values[i - 1], + value, + closest_indices[i] - closest_indices[i - 1], ) - prev_value = value - prev_index = closest_indices[i] - new_points = list(itertools.chain(*points)) + if i > 0 + else np.array([]), + range(len(self.high_symmetry_path_values)), + self.high_symmetry_path_values, + ) + ) + # and handle the last segment to include all points + points_segments[-1] = np.linspace( + self.high_symmetry_path_values[-2], + self.high_symmetry_path_values[-1], + num=closest_indices[-1] - closest_indices[-2] + 1, + ) + + # Flatten the list of segments into a single list of points + new_points = list(chain.from_iterable(points_segments)) + # And store this information in the `points` quantity if self.points is not None: logger.info('Overwriting `KLinePath.points` with the resolved points.') diff --git a/tests/test_numerical_settings.py b/tests/test_numerical_settings.py index d1ac2208..fcb4c8ff 100644 --- a/tests/test_numerical_settings.py +++ b/tests/test_numerical_settings.py @@ -26,7 +26,7 @@ from nomad_simulations.numerical_settings import KMesh, KLinePath from . import logger -from .conftest import generate_k_space_simulation +from .conftest import generate_k_line_path, generate_k_space_simulation class TestKSpace: @@ -259,6 +259,31 @@ class TestKLinePath: Test the `KLinePath` class defined in `numerical_settings.py`. """ + @pytest.mark.parametrize( + 'high_symmetry_path_names, high_symmetry_path_values, result', + [ + (None, None, False), + ([], [], False), + (['Gamma', 'X', 'Y'], None, False), + ([], [[0, 0, 0], [0.5, 0, 0], [0, 0.5, 0]], False), + (['Gamma', 'X', 'Y'], [[0, 0, 0], [0.5, 0, 0], [0, 0.5, 0]], True), + ], + ) + def test_check_high_symmetry_path( + self, + high_symmetry_path_names: List[str], + high_symmetry_path_values: List[List[float]], + result: bool, + ): + """ + Test the `_check_high_symmetry_path` private method. + """ + k_line_path = generate_k_line_path( + high_symmetry_path_names=high_symmetry_path_names, + high_symmetry_path_values=high_symmetry_path_values, + ) + assert k_line_path._check_high_symmetry_path(logger) == result + def test_get_high_symmetry_path_norm(self, k_line_path: KLinePath): """ Test the `get_high_symmetry_path_norm` method.