Skip to content

Commit

Permalink
merge develop
Browse files Browse the repository at this point in the history
  • Loading branch information
EBB2675 committed Oct 18, 2024
2 parents c94e995 + 20d2fb2 commit a007e1e
Show file tree
Hide file tree
Showing 5 changed files with 282 additions and 160 deletions.
4 changes: 2 additions & 2 deletions src/nomad_simulations/schema_packages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class NOMADSimulationsEntryPoint(SchemaPackageEntryPoint):
description='Limite of the number of atoms in the unit cell to be treated for the system type classification from MatID to work. This is done to avoid overhead of the package.',
)
equal_cell_positions_tolerance: float = Field(
1e-12,
description='Tolerance (in meters) for the cell positions to be considered equal.',
12,
description='Decimal order or tolerance (in meters) for comparing cell positions.',
)

def load(self):
Expand Down
209 changes: 137 additions & 72 deletions src/nomad_simulations/schema_packages/model_system.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,25 @@
#
# Copyright The NOMAD Authors.
#
# This file is part of NOMAD. See https://nomad-lab.eu for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import re
from typing import TYPE_CHECKING, Optional
from functools import lru_cache
from hashlib import sha1
from typing import TYPE_CHECKING

import ase
import numpy as np
Expand All @@ -22,12 +42,17 @@
from nomad.units import ureg

if TYPE_CHECKING:
from collections.abc import Generator
from typing import Any, Callable, Optional

import pint
from nomad.datamodel.datamodel import EntryArchive
from nomad.metainfo import Context, Section
from structlog.stdlib import BoundLogger

from nomad_simulations.schema_packages.atoms_state import AtomsState
from nomad_simulations.schema_packages.utils import (
catch_not_implemented,
get_sibling_section,
is_not_representative,
)
Expand Down Expand Up @@ -200,6 +225,72 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
return


def _check_implemented(func: 'Callable'):
"""
Decorator to restrict the comparison functions to the same class.
"""

def wrapper(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return func(self, other)

return wrapper


class PartialOrderElement:
def __init__(self, representative_variable):
self.representative_variable = representative_variable

def __hash__(self):
return self.representative_variable.__hash__()

@_check_implemented
def __eq__(self, other):
return self.representative_variable == other.representative_variable

@_check_implemented
def __lt__(self, other):
return False

@_check_implemented
def __gt__(self, other):
return False

def __le__(self, other):
return self.__eq__(other)

def __ge__(self, other):
return self.__eq__(other)

# __ne__ assumes that usage in a finite set with its comparison definitions


class HashedPositions(PartialOrderElement):
# `representative_variable` is a `pint.Quantity` object

def __hash__(self):
hash_str = sha1(
np.ascontiguousarray(
np.round(
self.representative_variable.to_base_units().magnitude,
decimals=configuration.equal_cell_positions_tolerance,
out=None,
)
).tobytes()
).hexdigest()
return int(hash_str, 16)

def __eq__(self, other):
"""Equality as defined between HashedPositions."""
if (
self.representative_variable is None
or other.representative_variable is None
):
return NotImplemented
return np.allclose(self.representative_variable, other.representative_variable)


class Cell(GeometricSpace):
"""
A base section used to specify the cell quantities of a system at a given moment in time.
Expand All @@ -217,7 +308,7 @@ class Cell(GeometricSpace):
type=MEnum('original', 'primitive', 'conventional'),
description="""
Representation type of the cell structure. It might be:
- 'original' as in origanally parsed,
- 'original' as in originally parsed,
- 'primitive' as the primitive unit cell,
- 'conventional' as the conventional cell used for referencing.
""",
Expand Down Expand Up @@ -278,45 +369,36 @@ class Cell(GeometricSpace):
""",
)

def _check_positions(self, positions_1, positions_2) -> list:
# Check that all the `positions`` of `cell_1` match with the ones in `cell_2`
check_positions = []
for i1, pos1 in enumerate(positions_1):
for i2, pos2 in enumerate(positions_2):
if np.allclose(
pos1, pos2, atol=configuration.equal_cell_positions_tolerance
):
check_positions.append([i1, i2])
break
return check_positions

def is_equal_cell(self, other) -> bool:
"""
Check if the cell is equal to an`other` cell by comparing the `positions`.
Args:
other: The other cell to compare with.
Returns:
bool: True if the cells are equal, False otherwise.
"""
# TODO implement checks on `lattice_vectors` and other quantities to ensure the equality of primitive cells
if not isinstance(other, Cell):
return False
@staticmethod
def _generate_comparer(obj: 'Cell') -> 'Generator[Any, None, None]':
try:
return ((HashedPositions(pos)) for pos in obj.positions)
except AttributeError:
raise NotImplementedError

# If the `positions` are empty, return False
if self.positions is None or other.positions is None:
return False
@catch_not_implemented
def is_lt_cell(self, other) -> bool:
return set(self._generate_comparer(self)) < set(self._generate_comparer(other))

# The `positions` should have the same length (same number of positions)
if len(self.positions) != len(other.positions):
return False
n_positions = len(self.positions)
@catch_not_implemented
def is_gt_cell(self, other) -> bool:
return set(self._generate_comparer(self)) > set(self._generate_comparer(other))

check_positions = self._check_positions(
positions_1=self.positions, positions_2=other.positions
)
if len(check_positions) != n_positions:
return False
return True
@catch_not_implemented
def is_le_cell(self, other) -> bool:
return set(self._generate_comparer(self)) <= set(self._generate_comparer(other))

@catch_not_implemented
def is_ge_cell(self, other) -> bool:
return set(self._generate_comparer(self)) >= set(self._generate_comparer(other))

@catch_not_implemented
def is_equal_cell(self, other) -> bool: # TODO: improve naming
return set(self._generate_comparer(self)) == set(self._generate_comparer(other))

def is_ne_cell(self, other) -> bool:
# this does not hold in general, but here we use finite sets
return not self.is_equal_cell(other)

def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
super().normalize(archive, logger)
Expand Down Expand Up @@ -361,40 +443,20 @@ def __init__(self, m_def: 'Section' = None, m_context: 'Context' = None, **kwarg
# Set the name of the section
self.name = self.m_def.name

def is_equal_cell(self, other) -> bool:
"""
Check if the atomic cell is equal to an`other` atomic cell by comparing the `positions` and
the `AtomsState[*].chemical_symbol`.
Args:
other: The other atomic cell to compare with.
Returns:
bool: True if the atomic cells are equal, False otherwise.
"""
if not isinstance(other, AtomicCell):
return False

# Compare positions using the parent sections's `__eq__` method
if not super().is_equal_cell(other=other):
return False

# Check that the `chemical_symbol` of the atoms in `cell_1` match with the ones in `cell_2`
check_positions = self._check_positions(
positions_1=self.positions, positions_2=other.positions
)
@staticmethod
def _generate_comparer(obj: 'AtomicCell') -> 'Generator[Any, None, None]':
# presumes `atoms_state` mapping 1-to-1 with `positions` and conserves the order
try:
for atom in check_positions:
element_1 = self.atoms_state[atom[0]].chemical_symbol
element_2 = other.atoms_state[atom[1]].chemical_symbol
if element_1 != element_2:
return False
except Exception:
return False
return True
return (
(HashedPositions(pos), PartialOrderElement(st.chemical_symbol))
for pos, st in zip(obj.positions, obj.atoms_state)
)
except AttributeError:
raise NotImplementedError

def get_chemical_symbols(self, logger: 'BoundLogger') -> list[str]:
"""
Get the chemical symbols of the atoms in the atomic cell. These are defined on `atoms_state[*].chemical_symbol`.
Args:
logger (BoundLogger): The logger to log messages.
Expand All @@ -412,7 +474,7 @@ def get_chemical_symbols(self, logger: 'BoundLogger') -> list[str]:
chemical_symbols.append(atom_state.chemical_symbol)
return chemical_symbols

def to_ase_atoms(self, logger: 'BoundLogger') -> Optional[ase.Atoms]:
def to_ase_atoms(self, logger: 'BoundLogger') -> 'Optional[ase.Atoms]':
"""
Generates an ASE Atoms object with the most basic information from the parsed `AtomicCell`
section (labels, periodic_boundary_conditions, positions, and lattice_vectors).
Expand Down Expand Up @@ -602,8 +664,11 @@ class Symmetry(ArchiveSection):
)

def resolve_analyzed_atomic_cell(
self, symmetry_analyzer: SymmetryAnalyzer, cell_type: str, logger: 'BoundLogger'
) -> Optional[AtomicCell]:
self,
symmetry_analyzer: 'SymmetryAnalyzer',
cell_type: str,
logger: 'BoundLogger',
) -> 'Optional[AtomicCell]':
"""
Resolves the `AtomicCell` section from the `SymmetryAnalyzer` object and the cell_type
(primitive or conventional).
Expand Down Expand Up @@ -647,8 +712,8 @@ def resolve_analyzed_atomic_cell(
return atomic_cell

def resolve_bulk_symmetry(
self, original_atomic_cell: AtomicCell, logger: 'BoundLogger'
) -> tuple[Optional[AtomicCell], Optional[AtomicCell]]:
self, original_atomic_cell: 'AtomicCell', logger: 'BoundLogger'
) -> 'tuple[Optional[AtomicCell], Optional[AtomicCell]]':
"""
Resolves the symmetry of the material being simulated using MatID and the
originally parsed data under original_atomic_cell. It generates two other
Expand Down
1 change: 1 addition & 0 deletions src/nomad_simulations/schema_packages/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .utils import (
RussellSaundersState,
catch_not_implemented,
get_composition,
get_sibling_section,
get_variables,
Expand Down
18 changes: 17 additions & 1 deletion src/nomad_simulations/schema_packages/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from nomad.config import config

if TYPE_CHECKING:
from typing import Optional
from typing import Callable, Optional

from nomad.datamodel.data import ArchiveSection
from structlog.stdlib import BoundLogger
Expand Down Expand Up @@ -154,3 +154,19 @@ def get_composition(children_names: 'list[str]') -> str:
children_count_tup = np.unique(children_names, return_counts=True)
formula = ''.join([f'{name}({count})' for name, count in zip(*children_count_tup)])
return formula if formula else None


def catch_not_implemented(func: 'Callable') -> 'Callable':
"""
Decorator to default comparison functions outside the same class to `False`.
"""

def wrapper(self, other) -> bool:
if not isinstance(other, self.__class__):
return False # ? should this throw an error instead?
try:
return func(self, other)
except (TypeError, NotImplementedError):
return False

return wrapper
Loading

1 comment on commit a007e1e

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
src/nomad_simulations
   __init__.py4250%3–4
   _version.py11282%5–6
src/nomad_simulations/schema_packages
   __init__.py15287%39–41
   atoms_state.py1902189%13–15, 201–204, 228, 283–284, 352–353, 355, 537, 549–550, 611–615, 630–634, 641
   basis_set.py2502988%8–9, 122–133, 172–185, 244, 290, 471–475, 497–498, 542–545, 664, 695, 697
   general.py89891%4–7, 121, 185, 295–296, 306
   model_method.py2697871%10–12, 171–174, 177–184, 276–277, 297, 318–339, 355–381, 384–401, 587, 780, 791, 833–840, 878, 897, 977, 1034, 1109, 1223
   model_system.py3483789%45–51, 235, 254, 258, 261, 264, 290, 376–377, 454–455, 472–473, 686–689, 736–743, 917–918, 1139–1143, 1149–1150, 1158–1159, 1164, 1187
   numerical_settings.py2596176%12–14, 217, 219–220, 223–226, 230–231, 238–241, 250–253, 257–260, 262–265, 270–273, 279–282, 469–496, 571, 606–609, 633, 636, 681, 683–686, 690, 694, 741, 745–766, 821–822, 889
   outputs.py1201092%9–10, 252–255, 295–298, 323, 325, 362, 381
   physical_property.py102793%20–22, 202, 331–333
   variables.py861286%8–10, 98, 121, 145, 167, 189, 211, 233, 256, 276
src/nomad_simulations/schema_packages/properties
   band_gap.py51590%8–10, 135–136
   band_structure.py1232580%9–11, 232–265, 278, 285, 321–322, 325, 372–373, 378
   energies.py42979%7–9, 36, 57, 82, 103, 119, 134
   fermi_surface.py17476%7–9, 40
   forces.py22673%7–9, 36, 56, 79
   greens_function.py991387%7–9, 210–211, 214, 235–236, 239, 260–261, 264, 400
   hopping_matrix.py29583%7–9, 58, 94
   permittivity.py48883%7–9, 97–105
   spectral_profile.py26012851%9–11, 57–60, 95–98, 199–300, 356–368, 393–396, 416, 421–424, 466–502, 526, 573–576, 592–593, 598–604
   thermodynamics.py752764%7–9, 35, 56, 72, 81, 90, 101, 110, 137, 147, 157, 172–174, 177, 193, 213–215, 218, 234, 254–256, 259
src/nomad_simulations/schema_packages/utils
   utils.py791680%8–11, 65–74, 83–84, 89, 92, 169–170
TOTAL259951580% 

Tests Skipped Failures Errors Time
402 0 💤 0 ❌ 0 🔥 5.649s ⏱️

Please sign in to comment.