diff --git a/mess/basis.py b/mess/basis.py
index b9b4b0b..b032ec8 100644
--- a/mess/basis.py
+++ b/mess/basis.py
@@ -1,5 +1,6 @@
 """basis sets of Gaussian type orbitals"""
 
+from functools import cache
 from typing import Tuple
 
 import equinox as eqx
@@ -12,7 +13,14 @@
 from mess.orbital import Orbital, batch_orbitals
 from mess.primitive import Primitive
 from mess.structure import Structure
-from mess.types import FloatN, FloatNx3, FloatNxM, FloatNxN, IntN, default_fptype
+from mess.types import (
+    FloatN,
+    FloatNx3,
+    FloatNxM,
+    FloatNxN,
+    IntN,
+    default_fptype,
+)
 
 
 class Basis(eqx.Module):
@@ -47,6 +55,7 @@ def fixer(x):
 
         df = pd.DataFrame()
         df["orbital"] = self.orbital_index
+        df["atom"] = self.primitives.atom_index
         df["coefficient"] = self.coefficients
         df["norm"] = self.primitives.norm
         df["center"] = fixer(self.primitives.center)
@@ -95,6 +104,47 @@ def basisset(structure: Structure, basis_name: str = "sto-3g") -> Basis:
     Returns:
         Basis constructed from inputs
     """
+    orbitals = []
+    atom_index = []
+
+    for atom_id in range(structure.num_atoms):
+        element = int(structure.atomic_number[atom_id])
+        out = _bse_to_orbitals(basis_name, element)
+        atom_index.extend([atom_id] * sum(len(ao.primitives) for ao in out))
+        orbitals += out
+
+    primitives, coefficients, orbital_index = batch_orbitals(orbitals)
+    primitives = eqx.tree_at(lambda p: p.atom_index, primitives, jnp.array(atom_index))
+    center = structure.position[primitives.atom_index, :]
+    primitives = eqx.tree_at(lambda p: p.center, primitives, center)
+
+    return Basis(
+        orbitals=orbitals,
+        structure=structure,
+        primitives=primitives,
+        coefficients=coefficients,
+        orbital_index=orbital_index,
+        basis_name=basis_name,
+        max_L=int(np.max(primitives.lmn)),
+    )
+
+
+@cache
+def _bse_to_orbitals(basis_name: str, atomic_number: int) -> Tuple[Orbital]:
+    """
+    Look up basis set parameters on the basis set exchange and build a tuple of Orbital.
+
+    The output is cached to reuse the same objects for a given basis set and atomic
+    number. This can help save time when batching over different coordinates.
+
+    Args:
+        basis_name (str): The name of the basis set to lookup on the basis set exchange.
+        atomic_number (int): The atomic number for the element to retrieve.
+
+    Returns:
+        Tuple[Orbital]: Tuple of Orbital objects corresponding to the specified basis
+        set and atomic number.
+    """
     from basis_set_exchange import get_basis
     from basis_set_exchange.sort import sort_basis
 
@@ -110,38 +160,24 @@ def basisset(structure: Structure, basis_name: str = "sto-3g") -> Basis:
 
     bse_basis = get_basis(
         basis_name,
-        elements=structure.atomic_number.tolist(),
+        elements=atomic_number,
         uncontract_spdf=True,
         uncontract_general=True,
     )
     bse_basis = sort_basis(bse_basis)["elements"]
     orbitals = []
 
-    for a in range(structure.num_atoms):
-        center = structure.position[a, :]
-        shells = bse_basis[str(structure.atomic_number[a])]["electron_shells"]
-
-        for s in shells:
-            for lmn in LMN_MAP[s["angular_momentum"][0]]:
-                ao = Orbital.from_bse(
-                    center=center,
-                    alphas=np.array(s["exponents"], dtype=default_fptype()),
-                    lmn=np.array(lmn, dtype=np.int32),
-                    coefficients=np.array(s["coefficients"], dtype=default_fptype()),
-                )
-                orbitals.append(ao)
-
-    primitives, coefficients, orbital_index = batch_orbitals(orbitals)
-
-    return Basis(
-        orbitals=orbitals,
-        structure=structure,
-        primitives=primitives,
-        coefficients=coefficients,
-        orbital_index=orbital_index,
-        basis_name=basis_name,
-        max_L=int(np.max(primitives.lmn)),
-    )
+    for s in bse_basis[str(atomic_number)]["electron_shells"]:
+        for lmn in LMN_MAP[s["angular_momentum"][0]]:
+            ao = Orbital.from_bse(
+                center=np.zeros(3, dtype=default_fptype()),
+                alphas=np.array(s["exponents"], dtype=default_fptype()),
+                lmn=np.array(lmn, dtype=np.int32),
+                coefficients=np.array(s["coefficients"], dtype=default_fptype()),
+            )
+            orbitals.append(ao)
+
+    return tuple(orbitals)
 
 
 def basis_iter(basis: Basis):
diff --git a/mess/primitive.py b/mess/primitive.py
index 1338091..86a7d39 100644
--- a/mess/primitive.py
+++ b/mess/primitive.py
@@ -16,6 +16,7 @@ class Primitive(eqx.Module):
     alpha: float = eqx.field(converter=asfparray, default=1.0)
     lmn: Int3 = eqx.field(converter=asintarray, default=(0, 0, 0))
     norm: Optional[float] = None
+    atom_index: Optional[int] = None
 
     def __post_init__(self):
         if self.norm is None:
diff --git a/test/test_benchmark.py b/test/test_benchmark.py
index e3542da..c479aec 100644
--- a/test/test_benchmark.py
+++ b/test/test_benchmark.py
@@ -63,3 +63,12 @@ def harness():
         return E.block_until_ready(), C.block_until_ready()
 
     benchmark(harness)
+
+
+def test_construct_basis(benchmark):
+    def harness():
+        mol = molecule("water")
+        basis = basisset(mol, "6-31g")
+        return jax.block_until_ready(basis)
+
+    benchmark(harness)