Skip to content

Commit

Permalink
Linting
Browse files Browse the repository at this point in the history
  • Loading branch information
E-Rum committed Dec 11, 2024
1 parent 9b8670c commit 032cf9e
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions examples/torchpme/torchpme_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,29 @@
Learning Capabilities with torchpme
=======================================
:Authors: Egor Rumiantsev `@E-Rum <https://github.com/E-Rum/>`_;
Philip Loche `@PicoCentauri <https://github.com/PicoCentauri>`_
:Authors: Egor Rumiantsev `@E-Rum <https://github.com/E-Rum/>`_; Philip Loche
`@PicoCentauri <https://github.com/PicoCentauri>`_
This example demonstrates the capabilities of the `torchpme` package, focusing on
learning target charges and utilizing the :class:`CombinedPotential` class to
evaluate potentials that combine multiple pairwise interactions with optimizable ``weights``.
learning target charges and utilizing the :class:`CombinedPotential` class to evaluate
potentials that combine multiple pairwise interactions with optimizable ``weights``.
The ``weights`` are optimized to reproduce the energy of a system interacting purely
through Coulomb forces.
"""

import os

# %%
from typing import Dict

import ase.io
import ase.visualize.plot
import matplotlib.pyplot as plt
import mpltex
import numpy as np
import torch
from torchpme import CombinedPotential, EwaldCalculator, InversePowerLawPotential
from vesin import NeighborList


# %%
# Select computation device
device = "cpu"
Expand Down Expand Up @@ -84,6 +82,7 @@
atoms.get_potential_energy(), device=device, dtype=dtype
)


# %%
# Function to assign charges to atoms
def assign_charges(atoms, charge_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
Expand All @@ -96,6 +95,7 @@ def assign_charges(atoms, charge_dict: Dict[str, torch.Tensor]) -> torch.Tensor:

return charges.reshape(-1, 1)


# %%
# Define the energy computation
def compute_energy(charge_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
Expand All @@ -115,19 +115,18 @@ def compute_energy(charge_dict: Dict[str, torch.Tensor]) -> torch.Tensor:

return energy


# %%
# Define the loss function
def loss(charge_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Calculate the loss as the mean squared error between computed and reference energies."""
"""Calculate the loss as the mean squared error between computed and reference
energies."""
energy = compute_energy(charge_dict)
mse = torch.sum((energy - l_ref_energy) ** 2)

# Enforce charge neutrality as a penalty
total_charge = sum(charge_dict.values())
charge_penalty = total_charge**2

return mse.sum() # Optionally add charge_penalty for strict neutrality enforcement.


# %%
# Fit charge model

Expand Down Expand Up @@ -249,6 +248,7 @@ def loss(charge_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
"#000000", # Black
]


def plot_results(fname=None, show_snapshot=True):
"""
Plot the learning process for charges and kernel weights.
Expand Down Expand Up @@ -304,5 +304,6 @@ def plot_results(fname=None, show_snapshot=True):

plt.show()


# Call the plot function to visualize results
plot_results("figures/toy_model_learning.pdf", show_snapshot=True)
plot_results("toy_model_learning.pdf", show_snapshot=True)

0 comments on commit 032cf9e

Please sign in to comment.