Skip to content

Commit

Permalink
Correct typing (mypy)
Browse files Browse the repository at this point in the history
  • Loading branch information
ndaelman committed May 13, 2024
1 parent 1539c15 commit 333f15c
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions src/nomad_simulations/model_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
import ase
from ase import neighborlist as ase_nl
from ase.geometry import analysis as ase_as
from typing import Tuple, Optional
from typing import Optional, Callable
from structlog.stdlib import BoundLogger
import pint

from matid import SymmetryAnalyzer, Classifier # pylint: disable=import-error
from matid.classification.classifications import (
Expand All @@ -49,13 +50,13 @@
from .utils import get_sibling_section, is_not_representative


def check_attributes(attributes: list[str]) -> Optional[callable]:
def check_attributes(attributes: list[str]) -> Optional[Callable]:
"""
Check if the specified attributes are not None.
"""

def decorator(func: callable) -> callable:
def wrapper(self, *args, **kwargs) -> Optional[callable]:
def decorator(func: Callable) -> Callable:
def wrapper(self, *args, **kwargs) -> Optional[Callable]:
for attr in attributes:
if attr not in self.all_quantities: # ! verify
self.logger.error(
Expand Down Expand Up @@ -453,10 +454,10 @@ class DistributionHistogram:

def __init__(
self,
combo: tuple[str],
combo: list[str],
type: str, # ? replace for automated check?
distribution_values: np.ndarray = np.array([]),
bins: np.ndarray = np.array([]),
distribution_values: pint.Quantity,
bins: np.ndarray,
) -> None:
self.combo = combo
self.type = type
Expand All @@ -483,6 +484,7 @@ def produce_nomad_distribution(self) -> GeometryDistribution:
extremity_atom_labels=[self.combo[0], self.combo[-1]],
n_bins=len(self.bins),
frequency=self.frequency,
bins=self.bins,
)
if len(self.combo) > 2:
geom_dist.central_atom_labels = self.combo[1:-1]
Expand All @@ -497,7 +499,7 @@ class Distribution:

def __init__(
self,
combo: tuple[str],
combo: list[str],
type: str,
ase_atoms: ase.Atoms,
neighbor_list: ase_nl.NeighborList,
Expand Down Expand Up @@ -544,34 +546,34 @@ def __init__(
self._elements = sorted(list(set(self.ase_atoms.get_chemical_symbols())))

@property
def get_elemental_pairs(self) -> list[tuple[str, str]]:
def get_elemental_pairs(self) -> list[list[str]]:
"""Generate all possible pairs of element symbols.
Permutations are not considered, i.e., (B, A) is converted into (A, B)."""
return [
(atom_1, atom_2)
[atom_1, atom_2]
for i, atom_1 in enumerate(self._elements)
for atom_2 in self._elements[i:]
]

@property
def get_elemental_triples_centered(self) -> list[tuple[str, str, str]]:
def get_elemental_triples_centered(self) -> list[list[str]]:
"""Generate all possible triples of element symbols with the center element first.
Permutations between the outer elements are not considered, i.e., (A, C, B) is converted into (A, B, C).
"""
return [
(atom_1, atom_c, atom_2)
[atom_1, atom_c, atom_2]
for atom_c in self._elements
for i, atom_1 in enumerate(self._elements)
for atom_2 in self._elements[i:]
] # matching the order of the `ase` package: https://wiki.fysik.dtu.dk/ase/_modules/ase/geometry/analysis.html#Analysis.get_angles

@property
def get_elemental_quadruples_centered(self) -> list[tuple[str, str, str, str]]:
def get_elemental_quadruples_centered(self) -> list[list[str]]:
"""Generate all possible quadruples of element symbols with the center elements first.
Permutations between the outer elements are not considered, i.e., (A, D, B, C) is converted into (A, B, C, D).
"""
return [
(atom_1, atom_c, atom_d, atom_2)
[atom_1, atom_c, atom_d, atom_2]
for atom_c in self._elements
for atom_d in self._elements
for i, atom_1 in enumerate(self._elements)
Expand Down Expand Up @@ -700,10 +702,8 @@ def normalize(self, archive, logger: BoundLogger):
self.histogram_distributions, self.distributions = [], []
for distribution in self.simple_distributions:
if distribution.type == 'distances':
bins = np.arange(
np.arange(0, max(self.geometry_analysis_cutoffs), 0.001)
)
elif distribution.type == 'angles':
bins = np.arange(0, max(self.geometry_analysis_cutoffs), 0.001)
elif distribution.type == 'angles' or distribution.type == 'dihedrals':
bins = np.arange(0, 180, 0.01)
self.histogram_distributions.append(
distribution.produce_histogram(bins)
Expand Down Expand Up @@ -870,7 +870,7 @@ def resolve_analyzed_atomic_cell(

def resolve_bulk_symmetry(
self, original_atomic_cell: AtomicCell, logger: BoundLogger
) -> Tuple[Optional[AtomicCell], Optional[AtomicCell]]:
) -> tuple[Optional[AtomicCell], Optional[AtomicCell]]:
"""
Resolves the symmetry of the material being simulated using MatID and the
originally parsed data under original_atomic_cell. It generates two other
Expand Down Expand Up @@ -1234,7 +1234,7 @@ class ModelSystem(System):

def resolve_system_type_and_dimensionality(
self, ase_atoms: ase.Atoms, logger: BoundLogger
) -> Tuple[str, int]:
) -> tuple[str, int]:
"""
Resolves the `ModelSystem.type` and `ModelSystem.dimensionality` using `MatID` classification analyzer:
Expand Down

0 comments on commit 333f15c

Please sign in to comment.