From 48e98321a5991a434deacb7b3a7e317ec5254944 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Mon, 30 Sep 2024 09:32:28 -0600 Subject: [PATCH 1/4] Cache orbitals when constructing basis set --- mess/basis.py | 105 +++++++++++++++++++++++++++++++---------- test/test_benchmark.py | 8 ++++ 2 files changed, 88 insertions(+), 25 deletions(-) diff --git a/mess/basis.py b/mess/basis.py index b9b4b0b..7aa42c6 100644 --- a/mess/basis.py +++ b/mess/basis.py @@ -1,5 +1,6 @@ """basis sets of Gaussian type orbitals""" +from functools import cache from typing import Tuple import equinox as eqx @@ -12,7 +13,15 @@ from mess.orbital import Orbital, batch_orbitals from mess.primitive import Primitive from mess.structure import Structure -from mess.types import FloatN, FloatNx3, FloatNxM, FloatNxN, IntN, default_fptype +from mess.types import ( + Float3, + FloatN, + FloatNx3, + FloatNxM, + FloatNxN, + IntN, + default_fptype, +) class Basis(eqx.Module): @@ -95,6 +104,42 @@ def basisset(structure: Structure, basis_name: str = "sto-3g") -> Basis: Returns: Basis constructed from inputs """ + orbitals = [] + + for a in range(structure.num_atoms): + element = int(structure.atomic_number[a]) + center = structure.position[a, :] + orbitals += _build_orbitals(basis_name, element, center) + + primitives, coefficients, orbital_index = batch_orbitals(orbitals) + + return Basis( + orbitals=orbitals, + structure=structure, + primitives=primitives, + coefficients=coefficients, + orbital_index=orbital_index, + basis_name=basis_name, + max_L=int(np.max(primitives.lmn)), + ) + + +@cache +def _bse_to_orbitals(basis_name: str, atomic_number: int) -> Tuple[Orbital]: + """ + Look up basis set parameters on the basis set exchange and build a tuple of Orbital. + + The output is cached to reuse the same objects for a given basis set and atomic + number. This can help save time when batching over different coordinates. + + Args: + basis_name (str): The name of the basis set to lookup on the basis set exchange. + atomic_number (int): The atomic number for the element to retrieve. + + Returns: + Tuple[Orbital]: Tuple of Orbital objects corresponding to the specified basis + set and atomic number. + """ from basis_set_exchange import get_basis from basis_set_exchange.sort import sort_basis @@ -110,38 +155,48 @@ def basisset(structure: Structure, basis_name: str = "sto-3g") -> Basis: bse_basis = get_basis( basis_name, - elements=structure.atomic_number.tolist(), + elements=atomic_number, uncontract_spdf=True, uncontract_general=True, ) bse_basis = sort_basis(bse_basis)["elements"] orbitals = [] - for a in range(structure.num_atoms): - center = structure.position[a, :] - shells = bse_basis[str(structure.atomic_number[a])]["electron_shells"] - - for s in shells: - for lmn in LMN_MAP[s["angular_momentum"][0]]: - ao = Orbital.from_bse( - center=center, - alphas=np.array(s["exponents"], dtype=default_fptype()), - lmn=np.array(lmn, dtype=np.int32), - coefficients=np.array(s["coefficients"], dtype=default_fptype()), - ) - orbitals.append(ao) + for s in bse_basis[str(atomic_number)]["electron_shells"]: + for lmn in LMN_MAP[s["angular_momentum"][0]]: + ao = Orbital.from_bse( + center=np.zeros(3, dtype=default_fptype()), + alphas=np.array(s["exponents"], dtype=default_fptype()), + lmn=np.array(lmn, dtype=np.int32), + coefficients=np.array(s["coefficients"], dtype=default_fptype()), + ) + orbitals.append(ao) - primitives, coefficients, orbital_index = batch_orbitals(orbitals) + return tuple(orbitals) - return Basis( - orbitals=orbitals, - structure=structure, - primitives=primitives, - coefficients=coefficients, - orbital_index=orbital_index, - basis_name=basis_name, - max_L=int(np.max(primitives.lmn)), - ) + +def _build_orbitals( + basis_name: str, atomic_number: int, center: Float3 +) -> Tuple[Orbital]: + """ + Constructs a tuple of Orbital objects for a given atomic_number and basis set, + with each orbital centered at the specified coordinates. + + Args: + basis_name (str): The name of the basis set to use. + atomic_number (int): The atomic number used to build the orbitals. + center (Float3): the 3D coordinate specifying the center of the orbitals + + Returns: + Tuple[Orbital]: A tuple of Orbitals centered at the provided coordinates. + """ + orbitals = _bse_to_orbitals(basis_name, atomic_number) + + def where(orbitals): + return [p.center for ao in orbitals for p in ao.primitives] + + num_centers = len(where(orbitals)) + return eqx.tree_at(where, orbitals, replace=np.tile(center, (num_centers, 1))) def basis_iter(basis: Basis): diff --git a/test/test_benchmark.py b/test/test_benchmark.py index e3542da..3e35ac6 100644 --- a/test/test_benchmark.py +++ b/test/test_benchmark.py @@ -63,3 +63,11 @@ def harness(): return E.block_until_ready(), C.block_until_ready() benchmark(harness) + + +def test_construct_basis(benchmark): + def harness(): + mol = molecule("water") + basisset(mol, "6-31g") + + benchmark(harness) From f3ad64b152e9092b4bee970b7fb2a198c712b10d Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Tue, 1 Oct 2024 13:15:24 -0600 Subject: [PATCH 2/4] use block_until_ready on benchmark output --- test/test_benchmark.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_benchmark.py b/test/test_benchmark.py index 3e35ac6..c479aec 100644 --- a/test/test_benchmark.py +++ b/test/test_benchmark.py @@ -68,6 +68,7 @@ def harness(): def test_construct_basis(benchmark): def harness(): mol = molecule("water") - basisset(mol, "6-31g") + basis = basisset(mol, "6-31g") + return jax.block_until_ready(basis) benchmark(harness) From 48379e7cc71b1bd07f85df9964caa8ecf61076b9 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Tue, 1 Oct 2024 13:16:00 -0600 Subject: [PATCH 3/4] adding optional atom_index to primitives --- mess/primitive.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mess/primitive.py b/mess/primitive.py index 1338091..86a7d39 100644 --- a/mess/primitive.py +++ b/mess/primitive.py @@ -16,6 +16,7 @@ class Primitive(eqx.Module): alpha: float = eqx.field(converter=asfparray, default=1.0) lmn: Int3 = eqx.field(converter=asintarray, default=(0, 0, 0)) norm: Optional[float] = None + atom_index: Optional[int] = None def __post_init__(self): if self.norm is None: From 8551772f763bc1a3a8ff4ed0a8e34318ea022299 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Tue, 1 Oct 2024 13:38:39 -0600 Subject: [PATCH 4/4] adding atom_index field and simplified basis creation --- mess/basis.py | 39 ++++++++++----------------------------- 1 file changed, 10 insertions(+), 29 deletions(-) diff --git a/mess/basis.py b/mess/basis.py index 7aa42c6..b032ec8 100644 --- a/mess/basis.py +++ b/mess/basis.py @@ -14,7 +14,6 @@ from mess.primitive import Primitive from mess.structure import Structure from mess.types import ( - Float3, FloatN, FloatNx3, FloatNxM, @@ -56,6 +55,7 @@ def fixer(x): df = pd.DataFrame() df["orbital"] = self.orbital_index + df["atom"] = self.primitives.atom_index df["coefficient"] = self.coefficients df["norm"] = self.primitives.norm df["center"] = fixer(self.primitives.center) @@ -105,13 +105,18 @@ def basisset(structure: Structure, basis_name: str = "sto-3g") -> Basis: Basis constructed from inputs """ orbitals = [] + atom_index = [] - for a in range(structure.num_atoms): - element = int(structure.atomic_number[a]) - center = structure.position[a, :] - orbitals += _build_orbitals(basis_name, element, center) + for atom_id in range(structure.num_atoms): + element = int(structure.atomic_number[atom_id]) + out = _bse_to_orbitals(basis_name, element) + atom_index.extend([atom_id] * sum(len(ao.primitives) for ao in out)) + orbitals += out primitives, coefficients, orbital_index = batch_orbitals(orbitals) + primitives = eqx.tree_at(lambda p: p.atom_index, primitives, jnp.array(atom_index)) + center = structure.position[primitives.atom_index, :] + primitives = eqx.tree_at(lambda p: p.center, primitives, center) return Basis( orbitals=orbitals, @@ -175,30 +180,6 @@ def _bse_to_orbitals(basis_name: str, atomic_number: int) -> Tuple[Orbital]: return tuple(orbitals) -def _build_orbitals( - basis_name: str, atomic_number: int, center: Float3 -) -> Tuple[Orbital]: - """ - Constructs a tuple of Orbital objects for a given atomic_number and basis set, - with each orbital centered at the specified coordinates. - - Args: - basis_name (str): The name of the basis set to use. - atomic_number (int): The atomic number used to build the orbitals. - center (Float3): the 3D coordinate specifying the center of the orbitals - - Returns: - Tuple[Orbital]: A tuple of Orbitals centered at the provided coordinates. - """ - orbitals = _bse_to_orbitals(basis_name, atomic_number) - - def where(orbitals): - return [p.center for ao in orbitals for p in ao.primitives] - - num_centers = len(where(orbitals)) - return eqx.tree_at(where, orbitals, replace=np.tile(center, (num_centers, 1))) - - def basis_iter(basis: Basis): from jax import tree