Skip to content

Commit

Permalink
Merge pull request #172 from NLESC-JCER/issue170
Browse files Browse the repository at this point in the history
WIP : Fix Nan Issue
  • Loading branch information
NicoRenaud authored Dec 5, 2024
2 parents a1b4e11 + 91cfde4 commit f4a5fb0
Show file tree
Hide file tree
Showing 37 changed files with 87 additions and 84 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
strategy:
fail-fast: false
matrix:
version: [3.8]
version: ['3.8', '3.10']

steps:
- name: Cancel Previous Runs
Expand Down
23 changes: 0 additions & 23 deletions .github/workflows/draft-pdf.yml

This file was deleted.

2 changes: 1 addition & 1 deletion docs/example/scf/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
calc = ['pyscf', # pyscf
'adf', # adf 2019
'adf2019' # adf 2020+
][1]
][0]

# select an appropriate basis
basis = {
Expand Down
16 changes: 15 additions & 1 deletion qmctorch/utils/algebra_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch

import numpy as np
from scipy.special import factorial2 as f2

def btrace(M):
"""Computes the trace of batched matrices
Expand Down Expand Up @@ -43,6 +44,19 @@ def bdet2(M):
return M[..., 0, 0] * M[..., 1, 1] - M[..., 0, 1] * M[..., 1, 0]


def double_factorial(input):
"""Computes the double factorial of an array of int
Args:
input (List): input numbers
Returns:
List: values of the double factorial
"""
output = f2(input)
return np.array([1 if o==0 else o for o in output])


class BatchDeterminant(torch.autograd.Function):

@staticmethod
Expand Down
8 changes: 4 additions & 4 deletions qmctorch/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

def set_torch_double_precision():
"""Set the default precision to double for all torch tensors."""
torch.set_default_dtype = torch.float64
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.float64)
# torch.set_default_tensor_type(torch.DoubleTensor)


def set_torch_single_precision():
"""Set the default precision to single for all torch tensors."""
torch.set_default_dtype = torch.float32
torch.set_default_tensor_type(torch.FloatTensor)
torch.set_default_dtype(torch.float32)
# torch.set_default_tensor_type(torch.FloatTensor)


def fast_power(x, k, mask0=None, mask2=None):
Expand Down
28 changes: 10 additions & 18 deletions qmctorch/wavefunction/orbitals/norm_orbital.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import numpy as np

from ...utils.algebra_utils import double_factorial

def atomic_orbital_norm(basis):
"""Computes the norm of the atomic orbitals
Expand Down Expand Up @@ -77,16 +77,13 @@ def norm_gaussian_spherical(bas_n, bas_exp):
Returns:
torch.tensor: normalization factor
"""

from scipy.special import factorial2 as f2

bas_n = torch.tensor(bas_n)
bas_n = bas_n + 1.
exp1 = 0.25 * (2. * bas_n + 1.)

A = torch.tensor(bas_exp)**exp1
B = 2**(2. * bas_n + 3. / 2)
C = torch.as_tensor(f2(2 * bas_n.int() - 1) * np.pi **
C = torch.as_tensor(double_factorial(2 * bas_n.int() - 1) * np.pi **
0.5).type(torch.get_default_dtype())

return torch.sqrt(B / C) * A
Expand All @@ -106,22 +103,20 @@ def norm_slater_cartesian(a, b, c, n, exp):
Returns:
torch.tensor: normalization factor
"""
from scipy.special import factorial2 as f2

lvals = a + b + c + n + 1.

lfact = torch.as_tensor([np.math.factorial(int(2 * i))
for i in lvals]).type(torch.get_default_dtype())

prefact = 4 * np.pi * lfact / ((2 * exp)**(2 * lvals + 1))

num = torch.as_tensor(f2(2 * a.astype('int') - 1) *
f2(2 * b.astype('int') - 1) *
f2(2 * c.astype('int') - 1)
num = torch.as_tensor(double_factorial(2 * a.astype('int') - 1) *
double_factorial(2 * b.astype('int') - 1) *
double_factorial(2 * c.astype('int') - 1)
).type(torch.get_default_dtype())

denom = torch.as_tensor(
f2((2 * a + 2 * b + 2 * c + 1).astype('int')
double_factorial((2 * a + 2 * b + 2 * c + 1).astype('int')
)).type(torch.get_default_dtype())

return torch.sqrt(1. / (prefact * num / denom))
Expand All @@ -135,22 +130,19 @@ def norm_gaussian_cartesian(a, b, c, exp):
a (torch.tensor): exponent of x
b (torch.tensor): exponent of y
c (torch.tensor): exponent of z
exp (torch.tensor): Sater exponent
exp (torch.tensor): Slater exponent
Returns:
torch.tensor: normalization factor
"""

from scipy.special import factorial2 as f2

pref = torch.as_tensor((2 * exp / np.pi)**(0.75))
am1 = (2 * a - 1).astype('int')
x = (4 * exp)**(a / 2) / torch.sqrt(torch.as_tensor(f2(am1)))
x = (4 * exp)**(a / 2) / torch.sqrt(torch.as_tensor(double_factorial(am1)))

bm1 = (2 * b - 1).astype('int')
y = (4 * exp)**(b / 2) / torch.sqrt(torch.as_tensor(f2(bm1)))
y = (4 * exp)**(b / 2) / torch.sqrt(torch.as_tensor(double_factorial(bm1)))

cm1 = (2 * c - 1).astype('int')
z = (4 * exp)**(c / 2) / torch.sqrt(torch.as_tensor(f2(cm1)))
z = (4 * exp)**(c / 2) / torch.sqrt(torch.as_tensor(double_factorial(cm1)))

return (pref * x * y * z).type(torch.get_default_dtype())
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch
from torch.autograd import Variable, grad, gradcheck
from torch.autograd import Variable, grad
from qmctorch.wavefunction.jastrows.distance import ElectronElectronDistance
import unittest
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@

import numpy as np
import torch
from torch.autograd import Variable, grad, gradcheck
from torch.autograd import Variable, grad

from qmctorch.wavefunction.jastrows.elec_elec.jastrow_factor_electron_electron import JastrowFactorElectronElectron
from qmctorch.wavefunction.jastrows.elec_elec.kernels.fully_connected_jastrow_kernel import FullyConnectedJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

from qmctorch.wavefunction.jastrows.elec_elec.jastrow_factor_electron_electron import JastrowFactorElectronElectron
from qmctorch.wavefunction.jastrows.elec_elec.kernels.fully_connected_jastrow_kernel import FullyConnectedJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
3 changes: 2 additions & 1 deletion tests/wavefunction/jastrows/elec_elec/test_pade_jastrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

from qmctorch.wavefunction.jastrows.elec_elec.jastrow_factor_electron_electron import JastrowFactorElectronElectron
from qmctorch.wavefunction.jastrows.elec_elec.kernels.pade_jastrow_kernel import PadeJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

from qmctorch.wavefunction.jastrows.elec_elec.jastrow_factor_electron_electron import JastrowFactorElectronElectron
from qmctorch.wavefunction.jastrows.elec_elec.kernels.pade_jastrow_polynomial_kernel import PadeJastrowPolynomialKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

from qmctorch.wavefunction.jastrows.elec_elec.jastrow_factor_electron_electron import JastrowFactorElectronElectron
from qmctorch.wavefunction.jastrows.elec_elec.kernels.pade_jastrow_kernel import PadeJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

from qmctorch.wavefunction.jastrows.elec_elec.jastrow_factor_electron_electron import JastrowFactorElectronElectron
from qmctorch.wavefunction.jastrows.elec_elec.kernels.pade_jastrow_polynomial_kernel import PadeJastrowPolynomialKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import numpy as np
import torch
from torch.autograd import Variable, grad, gradcheck
from torch.autograd import Variable, grad
from qmctorch.wavefunction.jastrows.elec_elec_nuclei.jastrow_factor_electron_electron_nuclei import JastrowFactorElectronElectronNuclei
from qmctorch.wavefunction.jastrows.elec_elec_nuclei.kernels.boys_handy_jastrow_kernel import BoysHandyJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import numpy as np
import torch
from torch.autograd import Variable, grad, gradcheck
from torch.autograd import Variable, grad
from qmctorch.wavefunction.jastrows.elec_elec_nuclei.jastrow_factor_electron_electron_nuclei import JastrowFactorElectronElectronNuclei
from qmctorch.wavefunction.jastrows.elec_elec_nuclei.kernels.fully_connected_jastrow_kernel import FullyConnectedJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from torch.autograd import Variable, grad, gradcheck
from qmctorch.wavefunction.jastrows.elec_nuclei.jastrow_factor_electron_nuclei import JastrowFactorElectronNuclei
from qmctorch.wavefunction.jastrows.elec_nuclei.kernels import FullyConnectedJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

from qmctorch.wavefunction.jastrows.elec_nuclei.jastrow_factor_electron_nuclei import JastrowFactorElectronNuclei
from qmctorch.wavefunction.jastrows.elec_nuclei.kernels.pade_jastrow_kernel import PadeJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
3 changes: 2 additions & 1 deletion tests/wavefunction/jastrows/test_combined_terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from qmctorch.wavefunction.jastrows.elec_elec.kernels import PadeJastrowKernel as PadeJastrowKernelElecElec
from qmctorch.wavefunction.jastrows.elec_nuclei.kernels import PadeJastrowKernel as PadeJastrowKernelElecNuc
from qmctorch.wavefunction.jastrows.elec_elec_nuclei.kernels import BoysHandyJastrowKernel, FullyConnectedJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from qmctorch.scf import Molecule
from qmctorch.wavefunction.orbitals.backflow.kernels import BackFlowKernelBase
from qmctorch.wavefunction.jastrows.distance.electron_electron_distance import ElectronElectronDistance
torch.set_default_tensor_type(torch.DoubleTensor)
from qmctorch.utils import set_torch_double_precision
set_torch_double_precision()

torch.manual_seed(101)
np.random.seed(101)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from qmctorch.scf import Molecule
from qmctorch.wavefunction.orbitals.backflow.kernels import BackFlowKernelInverse
from qmctorch.wavefunction.jastrows.distance.electron_electron_distance import ElectronElectronDistance
torch.set_default_tensor_type(torch.DoubleTensor)
from qmctorch.utils import set_torch_double_precision
set_torch_double_precision()

torch.manual_seed(101)
np.random.seed(101)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from qmctorch.scf import Molecule
from qmctorch.wavefunction.orbitals.backflow.backflow_transformation import BackFlowTransformation
from qmctorch.wavefunction.orbitals.backflow.kernels import BackFlowKernelInverse
torch.set_default_tensor_type(torch.DoubleTensor)
from qmctorch.utils import set_torch_double_precision
set_torch_double_precision()

torch.manual_seed(101)
np.random.seed(101)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from qmctorch.scf import Molecule
from qmctorch.wavefunction.orbitals.backflow.orbital_dependent_backflow_transformation import OrbitalDependentBackFlowTransformation
from qmctorch.wavefunction.orbitals.backflow.kernels import BackFlowKernelInverse
torch.set_default_tensor_type(torch.DoubleTensor)
from qmctorch.utils import set_torch_double_precision
set_torch_double_precision()

torch.manual_seed(101)
np.random.seed(101)
Expand Down
3 changes: 2 additions & 1 deletion tests/wavefunction/orbitals/test_ao_derivatives_adf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

from qmctorch.scf import Molecule
from qmctorch.wavefunction import SlaterJastrow
from qmctorch.utils import set_torch_double_precision
from ...path_utils import PATH_TEST

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
3 changes: 2 additions & 1 deletion tests/wavefunction/orbitals/test_ao_derivatives_pyscf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

from qmctorch.scf import Molecule
from qmctorch.wavefunction import SlaterJastrow
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from qmctorch.wavefunction import SlaterJastrow
from qmctorch.wavefunction.orbitals.atomic_orbitals_backflow import AtomicOrbitalsBackFlow
from qmctorch.wavefunction.orbitals.backflow.kernels import BackFlowKernelInverse
torch.set_default_tensor_type(torch.DoubleTensor)
from qmctorch.utils import set_torch_double_precision
set_torch_double_precision()

torch.manual_seed(101)
np.random.seed(101)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from qmctorch.wavefunction import SlaterJastrow
from qmctorch.wavefunction.orbitals.atomic_orbitals_orbital_dependent_backflow import AtomicOrbitalsOrbitalDependentBackFlow
from qmctorch.wavefunction.orbitals.backflow.kernels import BackFlowKernelInverse
torch.set_default_tensor_type(torch.DoubleTensor)
from qmctorch.utils import set_torch_double_precision
set_torch_double_precision()

torch.manual_seed(101)
np.random.seed(101)
Expand Down
Loading

0 comments on commit f4a5fb0

Please sign in to comment.