diff --git a/src/nomad_simulations/general.py b/src/nomad_simulations/general.py index 45169b7e..0d5293ca 100644 --- a/src/nomad_simulations/general.py +++ b/src/nomad_simulations/general.py @@ -17,7 +17,7 @@ # import numpy as np -from typing import List +from typing import List, Callable, Any from structlog.stdlib import BoundLogger from nomad.units import ureg @@ -182,9 +182,23 @@ def _set_system_branch_depth( def resolve_composition_formula( self, system_parent: ModelSystem, logger: BoundLogger ) -> None: + """Determine and set the composition formula for system_parent and all of its + descendants. + + Args: + system_parent (ModelSystem): The upper-most level of the system hierarchy to consider. + logger (BoundLogger): The logger to log messages. """ - """ - def set_branch_composition(system: ModelSystem, subsystems: List[ModelSystem], atom_labels: List[str]) -> None: + def set_composition_formula(system: ModelSystem, subsystems: List[ModelSystem], atom_labels: List[str]) -> None: + """Determines the composition formula for `system` based on its `subsystems`. + If `system` has no children, the atom_labels are used to determine the formula. + + Args: + system (ModelSystem): The system under consideration. + subsystems (List[ModelSystem]): The children of system. + atom_labels (List[str]): The global list of atom labels corresponding + to the atom indices stored in system. + """ if not subsystems: atom_indices = system.atom_indices if system.atom_indices is not None else [] subsystem_labels = [np.array(atom_labels)[atom_indices]] if atom_labels else ['Unknown' for atom in range(len(atom_indices))] @@ -193,16 +207,24 @@ def set_branch_composition(system: ModelSystem, subsystems: List[ModelSystem], a if system.composition_formula is None: system.composition_formula = get_composition(subsystem_labels) - def traverse_system_recurs(system, atom_labels): + def get_composition_recurs(system: ModelSystem, atom_labels: List[str]) -> None: + """Traverse the system hierarchy downward and set the branch composition for + all (subs)systems at each level. + + Args: + system (ModelSystem): The system to traverse downward. + atom_labels (List[str]): The global list of atom labels corresponding + to the atom indices stored in system. + """ subsystems = system.model_system - set_branch_composition(system, subsystems, atom_labels) + set_composition_formula(system, subsystems, atom_labels) if subsystems: for subsystem in subsystems: - traverse_system_recurs(subsystem, atom_labels) + get_composition_recurs(subsystem, atom_labels) atoms_state = system_parent.cell[0].atoms_state if system_parent.cell is not None else [] atom_labels = [atom.chemical_symbol for atom in atoms_state] if atoms_state is not None else [] - traverse_system_recurs(system_parent, atom_labels) + get_composition_recurs(system_parent, atom_labels) def normalize(self, archive, logger) -> None: super(EntryData, self).normalize(archive, logger) @@ -226,5 +248,5 @@ def normalize(self, archive, logger) -> None: self._set_system_branch_depth(system_parent) if is_not_representative(system_parent, logger): - return + continue self.resolve_composition_formula(system_parent, logger) \ No newline at end of file diff --git a/src/nomad_simulations/utils/utils.py b/src/nomad_simulations/utils/utils.py index 68be8dfa..661ee46f 100644 --- a/src/nomad_simulations/utils/utils.py +++ b/src/nomad_simulations/utils/utils.py @@ -130,7 +130,7 @@ def is_not_representative(model_system, logger: BoundLogger): return True return False -# TODO Either update nomad.atomutils function and remove this one, or remove the one in atomutils if we prefer it here only +# TODO remove function in nomad.atomutils def get_composition(children_names: List[str]) -> str: """ Generates a generalized "chemical formula" based on the provided list `children_names`, diff --git a/tests/test_model_system.py b/tests/test_model_system.py index 60c8c0be..5f3d90a9 100644 --- a/tests/test_model_system.py +++ b/tests/test_model_system.py @@ -444,77 +444,6 @@ def test_normalize(self): assert np.isclose(model_system.elemental_composition[1].atomic_fraction, 1 / 3) - @pytest.mark.parametrize( - 'mol_label_list, n_mol_list, atom_labels_list, composition_formula_list', - [ - ( - ['H20'], - [3], - [['H', 'O', 'O']], - ['group_H20(1)', 'H20(3)', 'H(1)O(2)', 'H(1)O(2)', 'H(1)O(2)'] - ), # pure system - ( - ['H20', 'Methane'], - [5, 2], - [['H', 'O', 'O'], ['C', 'H', 'H', 'H', 'H']], - ['group_H20(1)group_Methane(1)', 'H20(5)', 'H(1)O(2)', 'H(1)O(2)', 'H(1)O(2)', 'H(1)O(2)', 'H(1)O(2)', 'Methane(2)', 'C(1)H(4)', 'C(1)H(4)'] - ), # binary mixture - ], - ) - def test_system_hierarchy_for_molecules( - self, - mol_label_list: List[str], - n_mol_list: List[int], - atom_labels_list: List[str], - composition_formula_list: List[str] - ): - """ - Test the `ModelSystem` normalization of 'composition_formula' for atoms and molecules. - """ - #? Does it make sense to test the setting of branch_label or branch_depth? - model_system = ModelSystem(is_representative=True) - model_system.branch_label = 'Total System' - model_system.branch_depth = 0 - atomic_cell = AtomicCell() - model_system.cell.append(atomic_cell) - model_system.atom_indices = [] - for (mol_label, n_mol, atom_labels) in zip(mol_label_list, n_mol_list, atom_labels_list): - # Create a branch in the hierarchy for this molecule type - model_system_mol_group = ModelSystem(branch_label='group' + mol_label) - model_system_mol_group.atom_indices = [] - model_system_mol_group.branch_label = f"group_{mol_label}" - model_system_mol_group.branch_depth = 1 - model_system.model_system.append(model_system_mol_group) - for _ in range(n_mol): - # Create a branch in the hierarchy for this molecule - model_system_mol = ModelSystem(branch_label=mol_label) - model_system_mol.branch_label = mol_label - model_system_mol.branch_depth = 2 - model_system_mol_group.model_system.append(model_system_mol) - # add the corresponding atoms to the global atom list - for atom_label in atom_labels: - atomic_cell.atoms_state.append(AtomsState(chemical_symbol = atom_label)) - n_atoms = len(atomic_cell.atoms_state) - atom_indices = np.arange(n_atoms - len(atom_labels), n_atoms) - model_system_mol.atom_indices = atom_indices - model_system_mol_group.atom_indices = np.append(model_system_mol_group.atom_indices, atom_indices) - model_system.atom_indices = np.append(model_system.atom_indices, atom_indices) - - model_system.normalize(EntryArchive(), logger) - - assert model_system.composition_formula == composition_formula_list[0] - ctr_comp = 1 - def get_system_recurs(sec_system, ctr_comp): - for sys in sec_system: - assert sys.composition_formula == composition_formula_list[ctr_comp] - ctr_comp += 1 - sec_subsystem = sys.model_system - if sec_subsystem: - ctr_comp = get_system_recurs(sec_subsystem, ctr_comp) - return ctr_comp - - get_system_recurs(model_system.model_system, ctr_comp) - @pytest.mark.parametrize( 'is_representative, has_atom_indices, mol_label_list, n_mol_list, atom_labels_list, composition_formula_list, custom_formulas', [ @@ -593,27 +522,20 @@ def test_system_hierarchy_for_molecules( composition_formula_list: List[str], custom_formulas: List[str] ): - """ - Test the `ModelSystem` normalization of 'composition_formula' for atoms and molecules. + """Test the `ModelSystem` normalization of 'composition_formula' for atoms and molecules. - Description of test parameters: - is_representative: - Boolean specifying if branch_depth = 0 is representative or not. + Args: + is_representative (bool): Specifies if branch_depth = 0 is representative or not. If not representative, the composition formulas should not be generated. - has_atom_indices: - Boolean specifying if the atom_indices should be populated during parsing. + has_atom_indices (bool): Specifies if the atom_indices should be populated during parsing. Without atom_indices, the composition formulas for the deepest level of the hierarchy should not be populated. - mol_label_list: - List of molecule types for generating the hierarchy. - n_mol_list: List[int]: - List of the number of molecules for each molecule type. Should be same + mol_label_list (List[str]): Molecule types for generating the hierarchy. + n_mol_list (List[int]): Number of molecules for each molecule type. Should be same length as mol_label_list. - atom_labels_list: - List of atom labels for each molecule type. Should be same length as + atom_labels_list (List[str]): Atom labels for each molecule type. Should be same length as mol_label_list, with each entry being a list of corresponding atom labels. - composition_formula_list: - This is the list of resulting composition formulas after normalization. The + composition_formula_list (List[str]): Resulting composition formulas after normalization. The ordering is dictated by the recursive traversing of the hierarchy in get_system_recurs(), which follows each branch to its deepest level before moving to the next branch, i.e., [model_system.composition_formula, @@ -621,10 +543,12 @@ def test_system_hierarchy_for_molecules( model_system.model_system[0].model_system[0].composition_formula, model_system.model_system[0].model_system[1].composition_formula, ..., model_system.model_system[1].composition_formula, ...] - custom_formulas: - This is a list of custom composition formulas that can be set in the generation + custom_formulas (List[str]): Custom composition formulas that can be set in the generation of the hierarchy, which will cause the normalize to ignore (i.e., not overwrite) these formula entries. The ordering is as described above. + + Returns: + None """ ### Generate the system hierarchy ###