From 3c08a87ecdb3b3fe29b8b75343bd80bf837a6e93 Mon Sep 17 00:00:00 2001 From: ndaelman Date: Mon, 7 Oct 2024 18:46:32 +0200 Subject: [PATCH] Fix type annotation --- .../schema_packages/model_system.py | 20 +++++++++---------- .../schema_packages/utils/utils.py | 4 ++-- tests/test_model_system.py | 4 +++- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/nomad_simulations/schema_packages/model_system.py b/src/nomad_simulations/schema_packages/model_system.py index c8edc76b..ac3db44d 100644 --- a/src/nomad_simulations/schema_packages/model_system.py +++ b/src/nomad_simulations/schema_packages/model_system.py @@ -233,7 +233,7 @@ def __hash__(self): return self.representative_variable.__hash__() @staticmethod - def _check_implemented(func: Callable): + def _check_implemented(func: 'Callable'): """ Decorator to restrict the comparison functions to the same class. """ @@ -370,7 +370,7 @@ class Cell(GeometricSpace): ) @staticmethod - def _generate_comparer(obj) -> Generator[Any, None, None]: + def _generate_comparer(obj: 'Cell') -> 'Generator[Any, None, None]': try: return ((HashedPositions(pos)) for pos in obj.positions) except AttributeError: @@ -398,7 +398,7 @@ def is_equal_cell(self, other) -> bool: # TODO: improve naming def is_ne_cell(self, other) -> bool: # this does not hold in general, but here we use finite sets - return not self.is_equal_cell(other) + return (not self.is_equal_cell(other)) def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: super().normalize(archive, logger) @@ -444,9 +444,7 @@ def __init__(self, m_def: 'Section' = None, m_context: 'Context' = None, **kwarg self.name = self.m_def.name @staticmethod - def _generate_comparer( - obj, - ) -> Generator[Any, None, None]: + def _generate_comparer(obj: 'AtomicCell') -> 'Generator[Any, None, None]': # presumes `atoms_state` mapping 1-to-1 with `positions` and conserves the order try: return ( @@ -456,7 +454,7 @@ def _generate_comparer( except AttributeError: raise NotImplementedError - def to_ase_atoms(self, logger: 'BoundLogger') -> Optional[ase.Atoms]: + def to_ase_atoms(self, logger: 'BoundLogger') -> 'Optional[ase.Atoms]': """ Generates an ASE Atoms object with the most basic information from the parsed `AtomicCell` section (labels, periodic_boundary_conditions, positions, and lattice_vectors). @@ -618,8 +616,8 @@ class Symmetry(ArchiveSection): ) def resolve_analyzed_atomic_cell( - self, symmetry_analyzer: SymmetryAnalyzer, cell_type: str, logger: 'BoundLogger' - ) -> Optional[AtomicCell]: + self, symmetry_analyzer: 'SymmetryAnalyzer', cell_type: str, logger: 'BoundLogger' + ) -> 'Optional[AtomicCell]': """ Resolves the `AtomicCell` section from the `SymmetryAnalyzer` object and the cell_type (primitive or conventional). @@ -663,8 +661,8 @@ def resolve_analyzed_atomic_cell( return atomic_cell def resolve_bulk_symmetry( - self, original_atomic_cell: AtomicCell, logger: 'BoundLogger' - ) -> tuple[Optional[AtomicCell], Optional[AtomicCell]]: + self, original_atomic_cell: 'AtomicCell', logger: 'BoundLogger' + ) -> 'tuple[Optional[AtomicCell], Optional[AtomicCell]]': """ Resolves the symmetry of the material being simulated using MatID and the originally parsed data under original_atomic_cell. It generates two other diff --git a/src/nomad_simulations/schema_packages/utils/utils.py b/src/nomad_simulations/schema_packages/utils/utils.py index f8b1b5c4..8e104d09 100644 --- a/src/nomad_simulations/schema_packages/utils/utils.py +++ b/src/nomad_simulations/schema_packages/utils/utils.py @@ -177,12 +177,12 @@ def get_composition(children_names: 'list[str]') -> str: return formula if formula else None -def catch_not_implemented(func: Callable): +def catch_not_implemented(func: 'Callable') -> 'Callable': """ Decorator to default comparison functions outside the same class to `False`. """ - def wrapper(self, other): + def wrapper(self, other) -> bool: if not isinstance(other, self.__class__): return False # ? should this throw an error instead? try: diff --git a/tests/test_model_system.py b/tests/test_model_system.py index 5cc7240a..0be82195 100644 --- a/tests/test_model_system.py +++ b/tests/test_model_system.py @@ -239,7 +239,9 @@ class TestAtomicCell: ), # different position-symbol map ], ) - def test_partial_order(self, cell_1: Cell, cell_2: Cell, result: dict[str, bool]): + def test_partial_order( + self, cell_1: 'Cell', cell_2: 'Cell', result: dict[str, bool] + ): """ Test the comparison operators of `Cell` and `AtomicCell`. """