diff --git a/README.rst b/README.rst index 073185d0..1f226640 100644 --- a/README.rst +++ b/README.rst @@ -50,6 +50,58 @@ can optionally be installed together and used as ``torchpme.metatensor`` via pip install torch-pme[metatensor] +.. marker-quickstart + +Quickstart +---------- + +Here is a simple example get you started with *torch-pme*: + +.. code-block:: python + + import torch + import torchpme + + # Single charge in a cubic box + positions = torch.zeros((1, 3), requires_grad=True) + cell = 8 * torch.eye(3) + charges = torch.tensor([[1.0]]) + + # No neighbors for a single atom; use `vesin` for neighbors if needed + neighbor_indices = torch.zeros((0, 2), dtype=torch.int64) + neighbor_distances = torch.zeros((0,)) + + # Tune P3M parameters (cutoff optional, useful to set for ML with fixed cutoff) + smearing, p3m_parameters, _ = torchpme.utils.tune_p3m( + sum_squared_charges=1, + cell=cell, + positions=positions, + cutoff=5.0, + ) + + # Initialize potential and calculator + potential = torchpme.CoulombPotential(smearing) + calculator = torchpme.P3MCalculator(potential, **p3m_parameters) + + # Compute (per-atom) potentials + potentials = calculator.forward( + charges=charges, + cell=cell, + positions=positions, + neighbor_indices=neighbor_indices, + neighbor_distances=neighbor_distances, + ) + + # Calculate total energy and forces + energy = torch.sum(charges * potentials) + energy.backward() + forces = -positions.grad + + print("Energy:", energy.item()) + print("Forces:", forces) + +For more examples and details, please refer to the `documentation`_. + .. marker-issues Having problems or ideas? diff --git a/docs/src/installation.rst b/docs/src/installation.rst index ae92f24e..8bf2abcc 100644 --- a/docs/src/installation.rst +++ b/docs/src/installation.rst @@ -2,4 +2,4 @@ .. include:: ../../README.rst :start-after: marker-installation - :end-before: marker-issues + :end-before: marker-quickstart diff --git a/src/torchpme/calculators/calculator.py b/src/torchpme/calculators/calculator.py index 42f0a5c5..7f2da0e3 100644 --- a/src/torchpme/calculators/calculator.py +++ b/src/torchpme/calculators/calculator.py @@ -233,7 +233,7 @@ def _validate_compute_parameters( if charges.dtype != dtype: raise ValueError( - f"type of `charges` ({cell.dtype}) must be same as `positions` " + f"type of `charges` ({charges.dtype}) must be same as `positions` " f"({dtype})" ) diff --git a/tests/calculators/test_calculator.py b/tests/calculators/test_calculator.py index 6c830508..5d585131 100644 --- a/tests/calculators/test_calculator.py +++ b/tests/calculators/test_calculator.py @@ -142,7 +142,7 @@ def test_invalid_shape_charges(): def test_invalid_dtype_charges(): calculator = CalculatorTest() match = ( - r"type of `charges` \(torch.float32\) must be same as `positions` " + r"type of `charges` \(torch.float64\) must be same as `positions` " r"\(torch.float32\)" ) with pytest.raises(ValueError, match=match):