Skip to content

Commit

Permalink
Added abstract class KMeshBase and functionalities in KLinePath
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed May 27, 2024
1 parent a9bdac2 commit 0113056
Showing 1 changed file with 180 additions and 118 deletions.
298 changes: 180 additions & 118 deletions src/nomad_simulations/numerical_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,134 @@ def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)


class KMeshBase(Mesh):
"""
A base section used for abstraction for `KMesh` and `KLinePath` sections. It contains the methods
`_check_reciprocal_lattice_vectors` and `resolve_high_symmetry_points` that are used in both sections.
"""

def _check_reciprocal_lattice_vectors(
self, reciprocal_lattice_vectors: Optional[pint.Quantity], logger: BoundLogger
) -> bool:
"""
Check if the `reciprocal_lattice_vectors` exist and if they have the same dimensionality as `grid`.
Args:
reciprocal_lattice_vectors (Optional[pint.Quantity]): The reciprocal lattice vectors of the atomic cell.
logger (BoundLogger): The logger to log messages.
Returns:
(bool): True if the `reciprocal_lattice_vectors` exist and have the same dimensionality as `grid`, False otherwise.
"""
if reciprocal_lattice_vectors is None or self.grid is None:
logger.warning(
'Could not find `reciprocal_lattice_vectors` from parent `KSpace` or could not find `KMesh.grid`.'
)
return False
if len(reciprocal_lattice_vectors) != 3 or len(self.grid) != 3:
logger.warning(
'The `reciprocal_lattice_vectors` and the `grid` should have the same dimensionality.'
)
return False
return True

def resolve_high_symmetry_points(
self,
model_systems: List[ModelSystem],
logger: BoundLogger,
eps: float = 3e-3,
) -> Optional[dict]:
"""
Resolves the `high_symmetry_points` from the list of `ModelSystem`. This method relies on using the `ModelSystem`
information in the sub-sections `Symmetry` and `AtomicCell`, and uses the ASE package to extract the
special (high symmetry) points information.
Args:
model_systems (List[ModelSystem]): The list of `ModelSystem` sections.
logger (BoundLogger): The logger to log messages.
eps (float, optional): Tolerance factor to define the `lattice` ASE object. Defaults to 3e-3.
Returns:
(Optional[dict]): The resolved `high_symmetry_points`.
"""
# Extracting `bravais_lattice` from `ModelSystem.symmetry` section and `ASE.cell` from `ModelSystem.cell`
lattice = None
for model_system in model_systems:
# General checks to proceed with normalization
if is_not_representative(model_system, logger):
continue
if model_system.symmetry is None:
logger.warning('Could not find `ModelSystem.symmetry`.')
continue
bravais_lattice = [symm.bravais_lattice for symm in model_system.symmetry]
if len(bravais_lattice) != 1:
logger.warning(
'Could not uniquely determine `bravais_lattice` from `ModelSystem.symmetry`.'
)
continue
bravais_lattice = bravais_lattice[0]

if model_system.cell is None:
logger.warning('Could not find `ModelSystem.cell`.')
continue
prim_atomic_cell = None
for atomic_cell in model_system.cell:
if atomic_cell.type == 'primitive':
prim_atomic_cell = atomic_cell
break
if prim_atomic_cell is None:
logger.warning(
'Could not find the primitive `AtomicCell` under `ModelSystem.cell`.'
)
continue
# function defined in AtomicCell
atoms = prim_atomic_cell.to_ase_atoms(logger)
cell = atoms.get_cell()
lattice = cell.get_bravais_lattice(eps)
break # only cover the first representative `ModelSystem`

# Checking if `bravais_lattice` and `lattice` are defined
if lattice is None:
logger.warning(
'Could not resolve `bravais_lattice` and `lattice` ASE object from the `ModelSystem`.'
)
return None

# Non-conventional ordering testing for certain lattices:
if bravais_lattice in ['oP', 'oF', 'oI', 'oS']:
a, b, c = lattice.a, lattice.b, lattice.c
assert a < b
if bravais_lattice != 'oS':
assert b < c
elif bravais_lattice in ['mP', 'mS']:
a, b, c = lattice.a, lattice.b, lattice.c
alpha = lattice.alpha * np.pi / 180
assert a <= c and b <= c # ordering of the conventional lattice
assert alpha < np.pi / 2

# Extracting the `high_symmetry_points` from the `lattice` object
special_points = lattice.get_special_points()
if special_points is None:
logger.warning(
'Could not find `lattice.get_special_points()` from the ASE package.'
)
return None
high_symmetry_points = {}
for key, value in lattice.get_special_points().items():
if key == 'G':
key = 'Gamma'
if bravais_lattice == 'tI':
if key == 'S':
key = 'Sigma'
elif key == 'S1':
key = 'Sigma1'
high_symmetry_points[key] = list(value)
return high_symmetry_points

def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)


class KMesh(Mesh):
"""
A base section used to specify the settings of a sampling mesh in reciprocal space. The `points` and other
Expand Down Expand Up @@ -219,31 +347,6 @@ class KMesh(Mesh):

# TODO add extraction of `high_symmetry_points` using BandStructureNormalizer idea (left for later when defining outputs.py)

def _check_reciprocal_lattice_vectors(
self, reciprocal_lattice_vectors: Optional[pint.Quantity], logger: BoundLogger
) -> bool:
"""
Check if the `reciprocal_lattice_vectors` exist and if they have the same dimensionality as `grid`.
Args:
reciprocal_lattice_vectors (Optional[pint.Quantity]): The reciprocal lattice vectors of the atomic cell.
logger (BoundLogger): The logger to log messages.
Returns:
(bool): True if the `reciprocal_lattice_vectors` exist and have the same dimensionality as `grid`, False otherwise.
"""
if reciprocal_lattice_vectors is None or self.grid is None:
logger.warning(
'Could not find `reciprocal_lattice_vectors` from parent `KSpace` or could not find `KMesh.grid`.'
)
return False
if len(reciprocal_lattice_vectors) != 3 or len(self.grid) != 3:
logger.warning(
'The `reciprocal_lattice_vectors` and the `grid` should have the same dimensionality.'
)
return False
return True

def resolve_points_and_offset(
self, logger: BoundLogger
) -> Tuple[Optional[List[np.ndarray]], Optional[np.ndarray]]:
Expand Down Expand Up @@ -344,99 +447,6 @@ def resolve_k_line_density(
return k_line_density
return None

def resolve_high_symmetry_points(
self,
model_systems: List[ModelSystem],
logger: BoundLogger,
eps: float = 3e-3,
) -> Optional[dict]:
"""
Resolves the `high_symmetry_points` of the `KMesh` from the list of `ModelSystem`. This method
relies on using the `ModelSystem` information in the sub-sections `Symmetry` and `AtomicCell`, and uses
the ASE package to extract the special (high symmetry) points information.
Args:
model_systems (List[ModelSystem]): The list of `ModelSystem` sections.
logger (BoundLogger): The logger to log messages.
eps (float, optional): Tolerance factor to define the `lattice` ASE object. Defaults to 3e-3.
Returns:
(Optional[dict]): The resolved `high_symmetry_points` of the `KMesh`.
"""
# Extracting `bravais_lattice` from `ModelSystem.symmetry` section and `ASE.cell` from `ModelSystem.cell`
lattice = None
for model_system in model_systems:
# General checks to proceed with normalization
if is_not_representative(model_system, logger):
continue
if model_system.symmetry is None:
logger.warning('Could not find `ModelSystem.symmetry`.')
continue
bravais_lattice = [symm.bravais_lattice for symm in model_system.symmetry]
if len(bravais_lattice) != 1:
logger.warning(
'Could not uniquely determine `bravais_lattice` from `ModelSystem.symmetry`.'
)
continue
bravais_lattice = bravais_lattice[0]

if model_system.cell is None:
logger.warning('Could not find `ModelSystem.cell`.')
continue
prim_atomic_cell = None
for atomic_cell in model_system.cell:
if atomic_cell.type == 'primitive':
prim_atomic_cell = atomic_cell
break
if prim_atomic_cell is None:
logger.warning(
'Could not find the primitive `AtomicCell` under `ModelSystem.cell`.'
)
continue
# function defined in AtomicCell
atoms = prim_atomic_cell.to_ase_atoms(logger)
cell = atoms.get_cell()
lattice = cell.get_bravais_lattice(eps)
break # only cover the first representative `ModelSystem`

# Checking if `bravais_lattice` and `lattice` are defined
if lattice is None:
logger.warning(
'Could not resolve `bravais_lattice` and `lattice` ASE object from the `ModelSystem`.'
)
return None

# Non-conventional ordering testing for certain lattices:
if bravais_lattice in ['oP', 'oF', 'oI', 'oS']:
a, b, c = lattice.a, lattice.b, lattice.c
assert a < b
if bravais_lattice != 'oS':
assert b < c
elif bravais_lattice in ['mP', 'mS']:
a, b, c = lattice.a, lattice.b, lattice.c
alpha = lattice.alpha * np.pi / 180
assert a <= c and b <= c # ordering of the conventional lattice
assert alpha < np.pi / 2

# Extracting the `high_symmetry_points` from the `lattice` object
special_points = lattice.get_special_points()
if special_points is None:
logger.warning(
'Could not find `lattice.get_special_points()` from the ASE package.'
)
return None
high_symmetry_points = {}
for key, value in lattice.get_special_points().items():
if key == 'G':
key = 'Gamma'
if bravais_lattice == 'tI':
if key == 'S':
key = 'Sigma'
elif key == 'S1':
key = 'Sigma1'
high_symmetry_points[key] = list(value)
return high_symmetry_points

def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)

Expand Down Expand Up @@ -551,6 +561,43 @@ def _check_high_symmetry_path(self, logger: BoundLogger) -> bool:
return False
return True

def resolve_high_symmetry_path_values(
self,
model_systems: List[ModelSystem],
reciprocal_lattice_vectors: pint.Quantity,
logger: BoundLogger,
) -> Optional[List[float]]:
"""
Resolves the `high_symmetry_path_values` of the `KLinePath` from the `high_symmetry_path_names`.
Args:
model_systems (List[ModelSystem]): The list of `ModelSystem` sections.
reciprocal_lattice_vectors (pint.Quantity): The reciprocal lattice vectors of the atomic cell.
logger (BoundLogger): The logger to log messages.
Returns:
(Optional[List[float]]): The resolved `high_symmetry_path_values`.
"""
# Initial check on the `reciprocal_lattice_vectors`
if not self._check_reciprocal_lattice_vectors(
reciprocal_lattice_vectors, logger
):
return []

# Resolving the dictionary containing the `high_symmetry_points` for the given ModelSystem symmetry
high_symmetry_points = self.resolve_high_symmetry_points(model_systems, logger)
if high_symmetry_points is None:
return []

# Appending into a list which is stored in the `high_symmetry_path_values`. There is a check in the `normalize()`
# function to ensure that the length of the `high_symmetry_path_names` and `high_symmetry_path_values` coincide.
high_symmetry_path_values = [
high_symmetry_points[name]
for name in self.high_symmetry_path_names
if name in high_symmetry_points.keys()
]
return high_symmetry_path_values

def get_high_symmetry_path_norms(
self,
reciprocal_lattice_vectors: Optional[pint.Quantity],
Expand Down Expand Up @@ -686,6 +733,21 @@ def linspace_segments(
def normalize(self, archive, logger) -> None:
super().normalize(archive, logger)

# Resolves `high_symmetry_path_values` from `high_symmetry_path_names`
model_systems = self.m_xpath(
'm_parent.m_parent.m_parent.model_system', dict=False
)
reciprocal_lattice_vectors = self.m_xpath(
'm_parent.reciprocal_lattice_vectors', dict=False
)
if (
self.high_symmetry_path_values is None
or len(self.high_symmetry_path_values) == 0
):
self.high_symmetry_path_values = self.resolve_high_symmetry_path_values(
model_systems, reciprocal_lattice_vectors, logger
)

# If `high_symmetry_path` is not defined, we do not normalize the KLinePath
if not self._check_high_symmetry_path(logger):
return
Expand Down

0 comments on commit 0113056

Please sign in to comment.