diff --git a/src/nomad_simulations/schema_packages/basis_set.py b/src/nomad_simulations/schema_packages/basis_set.py index 4aaf55f6..0428c0c6 100644 --- a/src/nomad_simulations/schema_packages/basis_set.py +++ b/src/nomad_simulations/schema_packages/basis_set.py @@ -247,11 +247,14 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: if self.n_primitive is not None: if self.exponents is not None and len(self.exponents) != self.n_primitive: raise ValueError( - f"Mismatch in number of exponents: expected {self.n_primitive}, found {len(self.exponents)}." + f'Mismatch in number of exponents: expected {self.n_primitive}, found {len(self.exponents)}.' ) - if self.contraction_coefficients is not None and len(self.contraction_coefficients) != self.n_primitive: + if ( + self.contraction_coefficients is not None + and len(self.contraction_coefficients) != self.n_primitive + ): raise ValueError( - f"Mismatch in number of contraction coefficients: expected {self.n_primitive}, found {len(self.contraction_coefficients)}." + f'Mismatch in number of contraction coefficients: expected {self.n_primitive}, found {len(self.contraction_coefficients)}.' ) diff --git a/src/nomad_simulations/schema_packages/numerical_settings.py b/src/nomad_simulations/schema_packages/numerical_settings.py index 63fb2ab4..37842f84 100644 --- a/src/nomad_simulations/schema_packages/numerical_settings.py +++ b/src/nomad_simulations/schema_packages/numerical_settings.py @@ -202,6 +202,12 @@ class NumericalIntegration(NumericalSettings): def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: super().normalize(archive, logger) + valid_coordinates = ['full', 'radial', 'angular', None] + if self.coordinate not in valid_coordinates: + logger.warning( + f'Invalid coordinate value: {self.coordinate}. Resetting to None.' + ) + self.coordinate = None class KSpaceFunctionalities: diff --git a/tests/test_basis_set.py b/tests/test_basis_set.py index e96c224b..eb2e20f0 100644 --- a/tests/test_basis_set.py +++ b/tests/test_basis_set.py @@ -500,6 +500,7 @@ def test_atom_centered_basis_set_normalize() -> None: # Add checks for normalized behavior, if any assert bs.basis_set == 'cc-pVTZ' + def test_atom_centered_basis_set_invalid_data() -> None: """Test behavior with missing or invalid data.""" bs = AtomCenteredBasisSet( @@ -522,8 +523,5 @@ def test_atom_centered_basis_set_invalid_data() -> None: bs.functional_composition = [invalid_function] # Call normalize to trigger validation - with pytest.raises(ValueError, match="Mismatch in number of exponents"): + with pytest.raises(ValueError, match='Mismatch in number of exponents'): invalid_function.normalize(None, logger) - - - diff --git a/tests/test_numerical_settings.py b/tests/test_numerical_settings.py index 6426b4cd..0ead803f 100644 --- a/tests/test_numerical_settings.py +++ b/tests/test_numerical_settings.py @@ -9,6 +9,8 @@ KLinePath, KMesh, KSpaceFunctionalities, + Mesh, + NumericalIntegration, ) from . import logger @@ -377,3 +379,82 @@ def test_resolve_points(self, k_line_path: KLinePath): ] ) assert np.allclose(k_line_path.points, points) + + +@pytest.mark.parametrize( + 'dimensionality, expected_warning', + [ + (3, None), # Valid case + (2, None), # Valid case + ( + 5, + '`dimensionality` meshes different than 1, 2, or 3 are not supported.', + ), # Invalid + ( + 0, + '`dimensionality` meshes different than 1, 2, or 3 are not supported.', + ), # Invalid + ], +) +def test_mesh_dimensionality_validation(dimensionality, expected_warning, caplog): + mesh = Mesh(dimensionality=dimensionality) + mesh.normalize(None, logger) + if expected_warning: + assert expected_warning in caplog.text + else: + assert caplog.text == '' + + +@pytest.mark.parametrize( + 'dimensionality, grid, points', + [ + (3, [10, 10, 10], None), # Valid grid, no points defined yet + (3, None, None), # Missing grid and points + ( + 3, + [10, 10, 10], + [[0, 0, 0], [1, 1, 1]], + ), # Valid points (though fewer than grid suggests) + ], +) +def test_mesh_grid_and_points(dimensionality, grid, points): + mesh = Mesh(dimensionality=dimensionality, grid=grid, points=points) + assert mesh.dimensionality == dimensionality + if grid is not None: + assert np.allclose(mesh.grid, grid) + else: + assert mesh.grid == grid + if points is not None: + assert np.allclose(mesh.points, points) + else: + assert mesh.points == points + + +def test_mesh_spacing_normalization(): + mesh = Mesh(dimensionality=3, grid=[10, 10, 10], spacing=[0.1, 0.1, 0.1]) + mesh.normalize(None, logger) + assert np.allclose(mesh.spacing, [0.1, 0.1, 0.1]) + + +def test_numerical_integration_mesh(): + mesh = Mesh(dimensionality=3, grid=[10, 10, 10]) + integration = NumericalIntegration(mesh=mesh) + assert integration.mesh.dimensionality == 3 + assert np.allclose(integration.mesh.grid, [10, 10, 10]) + + +@pytest.mark.parametrize( + 'integration_thresh, weight_cutoff', + [ + (1e-6, 1e-3), # Valid thresholds + (None, 1e-3), # Missing integration threshold + (1e-6, None), # Missing weight cutoff + (None, None), # Both thresholds missing + ], +) +def test_numerical_integration_thresholds(integration_thresh, weight_cutoff): + integration = NumericalIntegration( + integration_thresh=integration_thresh, weight_cutoff=weight_cutoff + ) + assert integration.integration_thresh == integration_thresh + assert integration.weight_cutoff == weight_cutoff