From f313d49238569f125476fa1ad242a0a0c22cb66f Mon Sep 17 00:00:00 2001 From: Miki Bonacci Date: Fri, 22 Nov 2024 21:30:42 +0100 Subject: [PATCH 1/2] Adding support for aiida-atomistic - in pseudo and cutoff we introduce the naming `LegacyStructureData` for th `orm.StructureData`. - two additional tests, triggered only when we have atomistic installed (`HA` is `True`). --- src/aiida_pseudo/groups/family/pseudo.py | 24 +++++++++---- src/aiida_pseudo/groups/mixins/cutoffs.py | 10 ++++-- tests/groups/family/test_pseudo.py | 43 ++++++++++++++++++++++- 3 files changed, 67 insertions(+), 10 deletions(-) diff --git a/src/aiida_pseudo/groups/family/pseudo.py b/src/aiida_pseudo/groups/family/pseudo.py index 4463668..3068a7f 100644 --- a/src/aiida_pseudo/groups/family/pseudo.py +++ b/src/aiida_pseudo/groups/family/pseudo.py @@ -5,13 +5,19 @@ from aiida.common import exceptions from aiida.common.lang import classproperty, type_check from aiida.orm import Group, QueryBuilder +from aiida.orm.nodes.data.structure import has_atomistic from aiida.plugins import DataFactory from aiida_pseudo.data.pseudo import PseudoPotentialData __all__ = ('PseudoPotentialFamily',) -StructureData = DataFactory('core.structure') +LegacyStructureData = DataFactory('core.structure') # pylint: disable=invalid-name + +# +HA = has_atomistic() +if HA: + StructureData = DataFactory('atomistic.structure') class PseudoPotentialFamily(Group): @@ -308,12 +314,12 @@ def get_pseudos( self, *, elements: Optional[Union[List[str], Tuple[str]]] = None, - structure: StructureData = None, - ) -> Mapping[str, StructureData]: + structure: Union[StructureData, LegacyStructureData] if HA else Union[LegacyStructureData] = None, + ) -> Mapping[str, Union[StructureData, LegacyStructureData] if HA else Union[LegacyStructureData]]: """Return the mapping of kind names on pseudo potential data nodes for the given list of elements or structure. :param elements: list of element symbols. - :param structure: the ``StructureData`` node. + :param structure: the ``StructureData`` or ``LegacyStructureData`` node. :return: dictionary mapping the kind names of a structure on the corresponding pseudo potential data nodes. :raises ValueError: if the family does not contain a pseudo for any of the elements of the given structure. """ @@ -323,11 +329,15 @@ def get_pseudos( if elements is None and structure is None: raise ValueError('have to specify one of the keyword arguments `elements` and `structure`.') - if elements is not None and not isinstance(elements, (list, tuple)) and not isinstance(elements, StructureData): + if elements is not None and not isinstance(elements, (list, tuple)): raise ValueError('elements should be a list or tuple of symbols.') - if structure is not None and not isinstance(structure, StructureData): - raise ValueError('structure should be a `StructureData` instance.') + if structure is not None and not ( + isinstance(structure, (LegacyStructureData, StructureData if HA else LegacyStructureData)) + ): + raise ValueError( + f'structure should be a `StructureData` or `LegacyStructureData` instance, not {type(structure)}.' + ) if structure is not None: return {kind.name: self.get_pseudo(kind.symbol) for kind in structure.kinds} diff --git a/src/aiida_pseudo/groups/mixins/cutoffs.py b/src/aiida_pseudo/groups/mixins/cutoffs.py index e19ba92..193de75 100644 --- a/src/aiida_pseudo/groups/mixins/cutoffs.py +++ b/src/aiida_pseudo/groups/mixins/cutoffs.py @@ -3,11 +3,17 @@ from typing import Optional from aiida.common.lang import type_check +from aiida.orm.nodes.data.structure import has_atomistic from aiida.plugins import DataFactory from aiida_pseudo.common.units import U -StructureData = DataFactory('core.structure') +LegacyStructureData = DataFactory('core.structure') # pylint: disable=invalid-name + +# +HA = has_atomistic() +if HA: + StructureData = DataFactory('atomistic.structure') __all__ = ('RecommendedCutoffMixin',) @@ -278,7 +284,7 @@ def get_recommended_cutoffs(self, *, elements=None, structure=None, stringency=N raise ValueError('at least one and only one of `elements` or `structure` should be defined') type_check(elements, (tuple, str), allow_none=True) - type_check(structure, StructureData, allow_none=True) + type_check(structure, (LegacyStructureData, StructureData) if HA else (LegacyStructureData), allow_none=True) if unit is not None: self.validate_cutoffs_unit(unit) diff --git a/tests/groups/family/test_pseudo.py b/tests/groups/family/test_pseudo.py index 9ee5276..694cad5 100644 --- a/tests/groups/family/test_pseudo.py +++ b/tests/groups/family/test_pseudo.py @@ -4,9 +4,12 @@ import pytest from aiida.common import exceptions from aiida.orm import QueryBuilder +from aiida.orm.nodes.data.structure import has_atomistic from aiida_pseudo.data.pseudo import PseudoPotentialData from aiida_pseudo.groups.family.pseudo import PseudoPotentialFamily +skip_atomistic = pytest.mark.skipif(not has_atomistic(), reason='Unable to import aiida-atomistic') + def test_type_string(): """Verify the `_type_string` class attribute is correctly set to the corresponding entry point name.""" @@ -408,7 +411,7 @@ def test_get_pseudos_raise(get_pseudo_family, generate_structure): with pytest.raises(ValueError, match='elements should be a list or tuple of symbols.'): family.get_pseudos(elements={'He', 'Ar'}) - with pytest.raises(ValueError, match='structure should be a `StructureData` instance.'): + with pytest.raises(ValueError, match='structure should be a `StructureData` or `LegacyStructureData` instance'): family.get_pseudos(structure={'He', 'Ar'}) with pytest.raises(ValueError, match=r'family `.*` does not contain pseudo for element `.*`'): @@ -454,3 +457,41 @@ def test_get_pseudos_structure_kinds(get_pseudo_family, generate_structure): assert isinstance(pseudos, dict) for element in elements: assert isinstance(pseudos[element], PseudoPotentialData) + + +@skip_atomistic +@pytest.mark.usefixtures('aiida_profile_clean') +def test_get_pseudos_atomsitic_structure(get_pseudo_family, generate_structure): + """ + Test the `PseudoPotentialFamily.get_pseudos` method when passing + an aiida-atomistic ``StructureData`` instance. + """ + + elements = ('Ar', 'He', 'Ne') + orm_structure = generate_structure(elements) + structure = orm_structure.to_atomistic() + family = get_pseudo_family(elements=elements) + + pseudos = family.get_pseudos(structure=structure) + assert isinstance(pseudos, dict) + for element in elements: + assert isinstance(pseudos[element], PseudoPotentialData) + + +@skip_atomistic +@pytest.mark.usefixtures('aiida_profile_clean') +def test_get_pseudos_atomistic_structure_kinds(get_pseudo_family, generate_structure): + """ + Test the `PseudoPotentialFamily.get_pseudos` for + an aiida-atomistic ``StructureData`` with kind names including digits. + """ + + elements = ('Ar1', 'Ar2') + orm_structure = generate_structure(elements) + structure = orm_structure.to_atomistic() + family = get_pseudo_family(elements=elements) + + pseudos = family.get_pseudos(structure=structure) + assert isinstance(pseudos, dict) + for element in elements: + assert isinstance(pseudos[element], PseudoPotentialData) From 511d02c04086d04717f65631e648e334a49d89a3 Mon Sep 17 00:00:00 2001 From: Miki Bonacci Date: Fri, 29 Nov 2024 12:09:45 +0100 Subject: [PATCH 2/2] Improving conditionals logic for has_atomistic try to load the entry point atomistic.structure, otherwise we fallback into LegacyStructureData (MissingEntryPointError). --- src/aiida_pseudo/groups/family/pseudo.py | 21 +++++++++------------ src/aiida_pseudo/groups/mixins/cutoffs.py | 12 +++++++----- tests/groups/family/test_pseudo.py | 15 ++++++++++++--- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/src/aiida_pseudo/groups/family/pseudo.py b/src/aiida_pseudo/groups/family/pseudo.py index 3068a7f..d2e4a6a 100644 --- a/src/aiida_pseudo/groups/family/pseudo.py +++ b/src/aiida_pseudo/groups/family/pseudo.py @@ -5,7 +5,6 @@ from aiida.common import exceptions from aiida.common.lang import classproperty, type_check from aiida.orm import Group, QueryBuilder -from aiida.orm.nodes.data.structure import has_atomistic from aiida.plugins import DataFactory from aiida_pseudo.data.pseudo import PseudoPotentialData @@ -14,10 +13,12 @@ LegacyStructureData = DataFactory('core.structure') # pylint: disable=invalid-name -# -HA = has_atomistic() -if HA: +try: StructureData = DataFactory('atomistic.structure') +except exceptions.MissingEntryPointError: + structures_classes = (LegacyStructureData,) +else: + structures_classes = (LegacyStructureData, StructureData) class PseudoPotentialFamily(Group): @@ -314,8 +315,8 @@ def get_pseudos( self, *, elements: Optional[Union[List[str], Tuple[str]]] = None, - structure: Union[StructureData, LegacyStructureData] if HA else Union[LegacyStructureData] = None, - ) -> Mapping[str, Union[StructureData, LegacyStructureData] if HA else Union[LegacyStructureData]]: + structure: Union[structures_classes] = None, + ) -> Mapping[str, Union[structures_classes]]: """Return the mapping of kind names on pseudo potential data nodes for the given list of elements or structure. :param elements: list of element symbols. @@ -332,12 +333,8 @@ def get_pseudos( if elements is not None and not isinstance(elements, (list, tuple)): raise ValueError('elements should be a list or tuple of symbols.') - if structure is not None and not ( - isinstance(structure, (LegacyStructureData, StructureData if HA else LegacyStructureData)) - ): - raise ValueError( - f'structure should be a `StructureData` or `LegacyStructureData` instance, not {type(structure)}.' - ) + if structure is not None and not (isinstance(structure, structures_classes)): + raise ValueError(f'structure is of type {type(structure)} but should be of: {structures_classes}') if structure is not None: return {kind.name: self.get_pseudo(kind.symbol) for kind in structure.kinds} diff --git a/src/aiida_pseudo/groups/mixins/cutoffs.py b/src/aiida_pseudo/groups/mixins/cutoffs.py index 193de75..3f37f57 100644 --- a/src/aiida_pseudo/groups/mixins/cutoffs.py +++ b/src/aiida_pseudo/groups/mixins/cutoffs.py @@ -2,18 +2,20 @@ import warnings from typing import Optional +from aiida.common.exceptions import MissingEntryPointError from aiida.common.lang import type_check -from aiida.orm.nodes.data.structure import has_atomistic from aiida.plugins import DataFactory from aiida_pseudo.common.units import U LegacyStructureData = DataFactory('core.structure') # pylint: disable=invalid-name -# -HA = has_atomistic() -if HA: +try: StructureData = DataFactory('atomistic.structure') +except MissingEntryPointError: + structures_classes = (LegacyStructureData,) +else: + structures_classes = (LegacyStructureData, StructureData) __all__ = ('RecommendedCutoffMixin',) @@ -284,7 +286,7 @@ def get_recommended_cutoffs(self, *, elements=None, structure=None, stringency=N raise ValueError('at least one and only one of `elements` or `structure` should be defined') type_check(elements, (tuple, str), allow_none=True) - type_check(structure, (LegacyStructureData, StructureData) if HA else (LegacyStructureData), allow_none=True) + type_check(structure, (structures_classes), allow_none=True) if unit is not None: self.validate_cutoffs_unit(unit) diff --git a/tests/groups/family/test_pseudo.py b/tests/groups/family/test_pseudo.py index 694cad5..d736944 100644 --- a/tests/groups/family/test_pseudo.py +++ b/tests/groups/family/test_pseudo.py @@ -4,11 +4,18 @@ import pytest from aiida.common import exceptions from aiida.orm import QueryBuilder -from aiida.orm.nodes.data.structure import has_atomistic +from aiida.plugins import DataFactory from aiida_pseudo.data.pseudo import PseudoPotentialData from aiida_pseudo.groups.family.pseudo import PseudoPotentialFamily -skip_atomistic = pytest.mark.skipif(not has_atomistic(), reason='Unable to import aiida-atomistic') +try: + DataFactory('atomistic.structure') +except exceptions.MissingEntryPointError: + has_atomistic = False +else: + has_atomistic = True + +skip_atomistic = pytest.mark.skipif(not has_atomistic, reason='Unable to import aiida-atomistic') def test_type_string(): @@ -411,7 +418,9 @@ def test_get_pseudos_raise(get_pseudo_family, generate_structure): with pytest.raises(ValueError, match='elements should be a list or tuple of symbols.'): family.get_pseudos(elements={'He', 'Ar'}) - with pytest.raises(ValueError, match='structure should be a `StructureData` or `LegacyStructureData` instance'): + with pytest.raises( + ValueError, match=r"but should be of: \(," + ): family.get_pseudos(structure={'He', 'Ar'}) with pytest.raises(ValueError, match=r'family `.*` does not contain pseudo for element `.*`'):