From 657926be56d94ab6c560bfe4c637485387829048 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Fri, 13 Dec 2024 15:01:53 +0100 Subject: [PATCH] Quickstart and Basic usage example (#123) * Add skeleton for basic usage example * Add basic usage notebook * Make example complete apart from ML model * Move notebook to example * update example * Improve descriptions with more details * Add quickstart * start proofread * Add more details on vesin --------- Co-authored-by: Kevin Kazuki Huguenin-Dumittan --- README.rst | 52 ++++ docs/src/index.rst | 1 + docs/src/installation.rst | 4 +- docs/src/references/potentials/potential.rst | 2 +- docs/src/references/potentials/spline.rst | 2 +- docs/src/references/utils/prefactors.rst | 5 +- examples/2-neighbor-lists-usage.py | 4 +- examples/basic-usage.py | 291 +++++++++++++++++++ src/torchpme/calculators/calculator.py | 2 +- tests/calculators/test_calculator.py | 2 +- 10 files changed, 355 insertions(+), 10 deletions(-) create mode 100644 examples/basic-usage.py diff --git a/README.rst b/README.rst index afac2b27..f19426f6 100644 --- a/README.rst +++ b/README.rst @@ -56,6 +56,58 @@ can optionally be installed together and used as ``torchpme.metatensor`` via pip install torch-pme[metatensor] +.. marker-quickstart + +Quickstart +---------- + +Here is a simple example get you started with *torch-pme*: + +.. code-block:: python + + import torch + import torchpme + + # Single charge in a cubic box + positions = torch.zeros((1, 3), requires_grad=True) + cell = 8 * torch.eye(3) + charges = torch.tensor([[1.0]]) + + # No neighbors for a single atom; use `vesin` for neighbors if needed + neighbor_indices = torch.zeros((0, 2), dtype=torch.int64) + neighbor_distances = torch.zeros((0,)) + + # Tune P3M parameters (cutoff optional, useful to set for ML with fixed cutoff) + smearing, p3m_parameters, _ = torchpme.utils.tune_p3m( + sum_squared_charges=1, + cell=cell, + positions=positions, + cutoff=5.0, + ) + + # Initialize potential and calculator + potential = torchpme.CoulombPotential(smearing) + calculator = torchpme.P3MCalculator(potential, **p3m_parameters) + + # Compute (per-atom) potentials + potentials = calculator.forward( + charges=charges, + cell=cell, + positions=positions, + neighbor_indices=neighbor_indices, + neighbor_distances=neighbor_distances, + ) + + # Calculate total energy and forces + energy = torch.sum(charges * potentials) + energy.backward() + forces = -positions.grad + + print("Energy:", energy.item()) + print("Forces:", forces) + +For more examples and details, please refer to the `documentation`_. + .. marker-issues Having problems or ideas? diff --git a/docs/src/index.rst b/docs/src/index.rst index d70caf3c..d8865901 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -27,6 +27,7 @@ Overview about installation + examples/basic-usage examples/index references/index contributing diff --git a/docs/src/installation.rst b/docs/src/installation.rst index 1c9bbf37..8bf2abcc 100644 --- a/docs/src/installation.rst +++ b/docs/src/installation.rst @@ -1,3 +1,5 @@ +.. _userdoc-installation: + .. include:: ../../README.rst :start-after: marker-installation - :end-before: marker-issues + :end-before: marker-quickstart diff --git a/docs/src/references/potentials/potential.rst b/docs/src/references/potentials/potential.rst index 1765801a..bdbd9c61 100644 --- a/docs/src/references/potentials/potential.rst +++ b/docs/src/references/potentials/potential.rst @@ -1,7 +1,7 @@ Potential ######### -.. autoclass:: torchpme.potentials.Potential +.. autoclass:: torchpme.Potential :members: .. minigallery:: torchpme.Potential diff --git a/docs/src/references/potentials/spline.rst b/docs/src/references/potentials/spline.rst index 8d1704d7..b862c0ba 100644 --- a/docs/src/references/potentials/spline.rst +++ b/docs/src/references/potentials/spline.rst @@ -1,7 +1,7 @@ SplinePotential ############### -.. autoclass:: torchpme.potentials.SplinePotential +.. autoclass:: torchpme.SplinePotential :members: .. minigallery:: torchpme.SplinePotential diff --git a/docs/src/references/utils/prefactors.rst b/docs/src/references/utils/prefactors.rst index f2641b61..de634ead 100644 --- a/docs/src/references/utils/prefactors.rst +++ b/docs/src/references/utils/prefactors.rst @@ -3,8 +3,9 @@ Prefactors ========== -This module contains common *prefactors* for converting between Gaussian units and other -unit systems. +This module contains common ``prefactors`` for converting between Gaussian units and +other unit systems that can be used to set a unit system when initializing a +:ref:`calculator `. Many scientific calculators and computational tools, such as ``torchpme``, are set to operate in Gaussian units by default, where the term :math:`1/(4 \pi \epsilon_0) = 1.0`, diff --git a/examples/2-neighbor-lists-usage.py b/examples/2-neighbor-lists-usage.py index d7fb3ba7..27b6eba0 100644 --- a/examples/2-neighbor-lists-usage.py +++ b/examples/2-neighbor-lists-usage.py @@ -70,9 +70,7 @@ # We now slightly displace the atoms from their initial positions randomly based on a # Gaussian distribution with a width of 0.1 Å to create non-zero forces. -rng = np.random.default_rng(42) -displacement = rng.normal(loc=0.0, scale=0.1, size=(len(atoms), 3)) -atoms.positions += displacement +atoms.rattle(stdev=0.1) chemiscope.show( frames=[atoms], diff --git a/examples/basic-usage.py b/examples/basic-usage.py new file mode 100644 index 00000000..05e3f1df --- /dev/null +++ b/examples/basic-usage.py @@ -0,0 +1,291 @@ +""" +Basic Usage +=========== + +.. currentmodule:: torchpme + +This example showcases how the main capabilities of ``torchpme``. +We build a simple ionic crystal and compute the electrostatic potential and the forces +for each atom. + +To follow this tutorial, it is assumed that torch-pme has been :ref:`installed +` on your computer. + +.. note:: + + Inside this tutorial and all other examples of the documentation you can click on + each object in a code block to be forwarded to the documentation of the class or + function to get further details. + +""" + +# %% + +import chemiscope +import torch +from ase import Atoms +from vesin.torch import NeighborList + +from torchpme import EwaldCalculator +from torchpme.potentials import CoulombPotential + +# %% +# +# We initially set the ``device`` and ``dtype`` for the calculations. We will use the +# CPU for this and double precision. If you have a CUDA devide you can also set the +# device to ``"cuda"`` to use the GPU and speed up the calculations. Double precision is +# a requirement for the neighbor list implementation that we are using here. + +device = "cpu" +dtype = torch.float64 + + +# %% +# +# Generate Simple Example Structures +# ---------------------------------- +# +# Throughout this tutorial, we will work with a simple atomic structure in three +# dimensions, which is a distorted version of the CsCl structure. The goal will be to +# compute the electrostatic potentials and the forces for each atom. +# +# We first generate a single unit cell of CsCl (2 atoms in the cell) where the Cs atoms +# get a charge of +1 and the Cl atom a charge of -1. + +atoms_unitcell = Atoms( + "CsCl", cell=torch.eye(3), positions=[[0, 0, 0], [0.5, 0.5, 0.5]] +) +atoms_unitcell.set_initial_charges([1, -1]) + +# %% +# +# We next generate a bigger structure by periodically repeating the unit cell 3 times in +# each direction and apply a small random distortion to all atoms. + +atoms = atoms_unitcell.repeat(3) +atoms.rattle(stdev=0.01) +atoms.wrap() # make sure all atoms are in the unit cell + +chemiscope.show( + frames=[atoms], + mode="structure", + settings=chemiscope.quick_settings(structure_settings={"unitCell": True}), +) + +# %% +# +# We now extract the required properties from the :class:`ase.Atoms` object and store +# them as individial variables as :class:`torch.Tensor`, which is the required input +# type for *torch-pme*. +# +# For the positions, we explicitly set `requires_grad=True`. This is because we are +# ultimately interested in computing the forces, which are the gradients of the total +# (electrostatic) energy with respect to the positions (up to a minus sign). +# *torch-pme* can automatically compute such gradients, for which we need to communicate +# at this point that we will need to take gradients with respect to the positions in +# the future. + +positions = torch.tensor( + atoms.positions, dtype=dtype, device=device, requires_grad=True +) +cell = torch.tensor(atoms.cell.array, dtype=dtype, device=device) +charges = torch.tensor( + atoms.get_initial_charges(), dtype=dtype, device=device +).unsqueeze(1) + +# %% +# +# Tuning: Find Optimal Hyperparameters +# ------------------------------------ +# +# Ewald and mesh methods require the specification of multiple hyperparameters, namely +# +# 1. the cutoff radius :math:`r_\mathrm{cut}`` for the short-range parts +# 2. the ``smearing`` parameter $\sigma$ determining the relative importance of the +# short-range and long-range terms in the split +# 3. either the mesh spacing :math:`h` for mesh-based methods, or a reciprocal space +# cutoff :math:`k_\mathrm{cut} = 2\pi/\lambda`` for the Ewald sum, where +# :math:`\lambda` is the shortest wavelength used in the Fourier series and +# corresponds to :math:`h` for mesh-based approaches +# +# For ML applications, we typically first select a short-range cutoff to be the same to +# conventional short-ranged ML models, and define the remaining parameters from there. +# In this example, we are simply computing the Coulomb potential, and thus compute the +# hyperparameters simply based on convergence criteria. + +box_length = cell[0, 0] +rcut = float(box_length) / 2 - 1e-10 +smearing = rcut / 5 +# lambda which gives the reciprocal space cutoff kcut=2*pi/lambda +lr_wavelength = smearing / 2 + +# %% +# +# However, especially for users without much experience on how to choose these +# hyperparameters, we have built-in tuning functions for each Calculator (see below) +# such as :func:`torchpme.utils.tune_ewald`, which can automatically find a good set +# of parameters. These can be used like this: +# +# .. code-block:: python +# +# sum_charges_sq = torch.sum(charges**2, dim=0) +# smearing, lr_wavelength, rcut = tune_ewald(sum_charges_sq, cell, positions, accuracy=1e-1) +# +# Define Potential +# ---------------- +# +# We now need to define the potential function with which the atoms interact. Since this +# is a library for long-range ML, we support three major options: +# +# 1. the :class:`Coulomb potential ` (:math:`1/r`) +# 2. more general :class:`inverse power-law potentials ` (:math:`1/r^p`) +# 3. an option to build custom potentials using :class:`splines ` +# +# This tutorial focuses on option (1), which is the most relevant from a practical point +# of view. We can simply initialize an instance of the :class:`CoulombPotential` class that +# contains all the necessary functions (such as those defining the short-range and +# long-range splits) for this potential and makes them useable in the rest of the code. + +potential = CoulombPotential(smearing=smearing) + +# %% +# +# Neighbor List +# ------------- +# +# *torch-pme* requires us to compute the neighbor list (NL) in advance. This is because +# for many ML applications, we would otherwise need to repeat this computation multiple +# times during model training of a neural network etc. By computing it externally and +# providing it as an input, we can streamline training workflows. +# We compute the neighbor list using the ``vesin`` package, which also provides a +# pytorch implementation that retains the computational graph. This will later allow us +# to automatically compute gradients of the distances with respect to the atomic +# coordinates. More details can be found in its documentation. + +nl = NeighborList(cutoff=rcut, full_list=False) +i, j, neighbor_distances = nl.compute( + points=positions, box=cell, periodic=True, quantities="ijd" +) +neighbor_indices = torch.stack([i, j], dim=1) + +# %% +# +# Main Part: Calculator +# --------------------- +# +# The ``Calculator`` classes are the main user-facing classes in *torch-pme*. These are +# used to compute atomic potentials :math:`V_i`` for a given set of positions and +# particle weights (charges). For periodic calculators that are the main focus of this +# tutorial, it is also required to specify a ``cell``. We have three periodic +# calculators: +# +# 1. :class:`EwaldCalculator`: uses the Ewald sum. Recommended for structures with +# :math:`<10000` atoms due to its high accuracy and simplicity. Should be avoided for +# big structures due to the slower :math:`\mathcal{O}(N^2)` to +# :math:`\mathcal{O}(N^{\frac{3}{2}})` scaling of the computational cost (depending +# on how the hyperparameters are chosen) with respect to the number of atoms. +# 2. :class:`PMECalculator`: uses the Particle Mesh Ewald (PME) method. Mostly for +# reference. +# 3. :class:`P3MCalculator`: uses the Particle-Particle Particle-Mesh (P3M) method. +# Recommended for structures :math:`\ge 10000` atoms due to the efficient +# :math:`\mathcal{O}(N\log N)` scaling of the computational cost with the number of +# atoms. +# +# Since our structure is relatively small, we use the :class:`EwaldCalculator`. +# We start by the initialization of the class. + +calculator = EwaldCalculator(potential=potential, lr_wavelength=lr_wavelength) + +# %% +# +# Compute Energy +# -------------- +# +# We have now all ingredients: we can use the ``Calculator`` class to, well, actually +# compute the potentials :math:`V_i` at the position of the atoms, or the total energy +# for the given particle weights (``charges``). The electrostatic potential can then be +# obtained as +# +# .. math:: +# +# E = \sum_{i=1}^N q_i V_i + +potentials = calculator.forward( + charges, cell, positions, neighbor_indices, neighbor_distances +) +energy = torch.sum(charges * potentials) + +print("Energy = \n", energy.item()) + +# %% +# +# .. hint:: +# +# The energy is in Gaussian units. You can change the unit system by setting the +# ``prefactor`` when initializing the calculator. For more details about this refer to +# :ref:`prefactors`. +# +# Compute Forces using backpropagation (automatic differentiation) +# ---------------------------------------------------------------- +# +# The forces on the particles can simply be obtained as minus the gradient of the energy +# with respect to the positions. These are easy to evaluate using the automatic +# differentiation capabilities of `pytorch +# `_ using the +# backpropagation method. +# +# Note that this only works since we set ``requires_grad=True`` when we initialized the +# positions tensor above. + +energy.backward() +forces = -positions.grad + +print("Force on first atom = \n", forces[0]) + +# %% +# +# Aperiodic Structures +# -------------------- +# +# For now, we have been using the :class:`EwaldCalculator` which is a periodic +# calculator. We can however also use it for aperiodic structures by just using it as a +# calculator with a cutoff radius. To start clone the positions to avoid accumulating +# gradients with respect to the same variables multiple times + +positions_aperiodic = positions.clone().detach() +positions_aperiodic.requires_grad = True + +# %% +# +# Compute neighbor list but this time without periodic boudary conditions + +i, j, neighbor_distances_aperiodic = nl.compute( + points=positions_aperiodic, box=cell, periodic=False, quantities="ijd" +) +neighbor_indices_aperiodic = torch.stack([i, j], dim=1) + +# %% +# +# Compute aperiodic potential using the dedicated subroutine that is present +# in all calculators. For now, we do not provide an explicit calculator for aperiodic +# systems since the focus in mainly on periodic ones. + +potentials_aperiodic = calculator._compute_rspace( + charges, neighbor_indices_aperiodic, neighbor_distances_aperiodic +) + +# %% +# +# Compute total energy and forces + +energy_aperiodic = torch.sum(charges * potentials_aperiodic) +energy_aperiodic.backward() +forces_aperiodic = positions_aperiodic.grad # forces on the particles + +# %% +# +# References to more advanced tasks +# --------------------------------- +# +# Refer to :ref:`examples ` for tackling specific and advanced tasks and +# to the :ref:`references ` for more information the spefific API. diff --git a/src/torchpme/calculators/calculator.py b/src/torchpme/calculators/calculator.py index 42f0a5c5..7f2da0e3 100644 --- a/src/torchpme/calculators/calculator.py +++ b/src/torchpme/calculators/calculator.py @@ -233,7 +233,7 @@ def _validate_compute_parameters( if charges.dtype != dtype: raise ValueError( - f"type of `charges` ({cell.dtype}) must be same as `positions` " + f"type of `charges` ({charges.dtype}) must be same as `positions` " f"({dtype})" ) diff --git a/tests/calculators/test_calculator.py b/tests/calculators/test_calculator.py index 6c830508..5d585131 100644 --- a/tests/calculators/test_calculator.py +++ b/tests/calculators/test_calculator.py @@ -142,7 +142,7 @@ def test_invalid_shape_charges(): def test_invalid_dtype_charges(): calculator = CalculatorTest() match = ( - r"type of `charges` \(torch.float32\) must be same as `positions` " + r"type of `charges` \(torch.float64\) must be same as `positions` " r"\(torch.float32\)" ) with pytest.raises(ValueError, match=match):