Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache orbitals when constructing basis set #17

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 63 additions & 27 deletions mess/basis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""basis sets of Gaussian type orbitals"""

from functools import cache
from typing import Tuple

import equinox as eqx
Expand All @@ -12,7 +13,14 @@
from mess.orbital import Orbital, batch_orbitals
from mess.primitive import Primitive
from mess.structure import Structure
from mess.types import FloatN, FloatNx3, FloatNxM, FloatNxN, IntN, default_fptype
from mess.types import (
FloatN,
FloatNx3,
FloatNxM,
FloatNxN,
IntN,
default_fptype,
)


class Basis(eqx.Module):
Expand Down Expand Up @@ -47,6 +55,7 @@ def fixer(x):

df = pd.DataFrame()
df["orbital"] = self.orbital_index
df["atom"] = self.primitives.atom_index
df["coefficient"] = self.coefficients
df["norm"] = self.primitives.norm
df["center"] = fixer(self.primitives.center)
Expand Down Expand Up @@ -95,6 +104,47 @@ def basisset(structure: Structure, basis_name: str = "sto-3g") -> Basis:
Returns:
Basis constructed from inputs
"""
orbitals = []
atom_index = []

for atom_id in range(structure.num_atoms):
element = int(structure.atomic_number[atom_id])
out = _bse_to_orbitals(basis_name, element)
atom_index.extend([atom_id] * sum(len(ao.primitives) for ao in out))
orbitals += out

primitives, coefficients, orbital_index = batch_orbitals(orbitals)
primitives = eqx.tree_at(lambda p: p.atom_index, primitives, jnp.array(atom_index))
center = structure.position[primitives.atom_index, :]
primitives = eqx.tree_at(lambda p: p.center, primitives, center)

return Basis(
orbitals=orbitals,
structure=structure,
primitives=primitives,
coefficients=coefficients,
orbital_index=orbital_index,
basis_name=basis_name,
max_L=int(np.max(primitives.lmn)),
)


@cache
def _bse_to_orbitals(basis_name: str, atomic_number: int) -> Tuple[Orbital]:
"""
Look up basis set parameters on the basis set exchange and build a tuple of Orbital.

The output is cached to reuse the same objects for a given basis set and atomic
number. This can help save time when batching over different coordinates.

Args:
basis_name (str): The name of the basis set to lookup on the basis set exchange.
atomic_number (int): The atomic number for the element to retrieve.

Returns:
Tuple[Orbital]: Tuple of Orbital objects corresponding to the specified basis
set and atomic number.
"""
from basis_set_exchange import get_basis
from basis_set_exchange.sort import sort_basis

Expand All @@ -110,38 +160,24 @@ def basisset(structure: Structure, basis_name: str = "sto-3g") -> Basis:

bse_basis = get_basis(
basis_name,
elements=structure.atomic_number.tolist(),
elements=atomic_number,
uncontract_spdf=True,
uncontract_general=True,
)
bse_basis = sort_basis(bse_basis)["elements"]
orbitals = []

for a in range(structure.num_atoms):
center = structure.position[a, :]
shells = bse_basis[str(structure.atomic_number[a])]["electron_shells"]

for s in shells:
for lmn in LMN_MAP[s["angular_momentum"][0]]:
ao = Orbital.from_bse(
center=center,
alphas=np.array(s["exponents"], dtype=default_fptype()),
lmn=np.array(lmn, dtype=np.int32),
coefficients=np.array(s["coefficients"], dtype=default_fptype()),
)
orbitals.append(ao)

primitives, coefficients, orbital_index = batch_orbitals(orbitals)

return Basis(
orbitals=orbitals,
structure=structure,
primitives=primitives,
coefficients=coefficients,
orbital_index=orbital_index,
basis_name=basis_name,
max_L=int(np.max(primitives.lmn)),
)
for s in bse_basis[str(atomic_number)]["electron_shells"]:
for lmn in LMN_MAP[s["angular_momentum"][0]]:
ao = Orbital.from_bse(
center=np.zeros(3, dtype=default_fptype()),
alphas=np.array(s["exponents"], dtype=default_fptype()),
lmn=np.array(lmn, dtype=np.int32),
coefficients=np.array(s["coefficients"], dtype=default_fptype()),
)
orbitals.append(ao)

return tuple(orbitals)


def basis_iter(basis: Basis):
Expand Down
1 change: 1 addition & 0 deletions mess/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class Primitive(eqx.Module):
alpha: float = eqx.field(converter=asfparray, default=1.0)
lmn: Int3 = eqx.field(converter=asintarray, default=(0, 0, 0))
norm: Optional[float] = None
atom_index: Optional[int] = None

def __post_init__(self):
if self.norm is None:
Expand Down
9 changes: 9 additions & 0 deletions test/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,12 @@ def harness():
return E.block_until_ready(), C.block_until_ready()

benchmark(harness)


def test_construct_basis(benchmark):
def harness():
mol = molecule("water")
basis = basisset(mol, "6-31g")
return jax.block_until_ready(basis)

benchmark(harness)