Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP : Fix Nan Issue #172

Merged
merged 12 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
strategy:
fail-fast: false
matrix:
version: [3.8]
version: ['3.8', '3.10']

steps:
- name: Cancel Previous Runs
Expand Down
23 changes: 0 additions & 23 deletions .github/workflows/draft-pdf.yml

This file was deleted.

2 changes: 1 addition & 1 deletion docs/example/scf/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
calc = ['pyscf', # pyscf
'adf', # adf 2019
'adf2019' # adf 2020+
][1]
][0]

# select an appropriate basis
basis = {
Expand Down
16 changes: 15 additions & 1 deletion qmctorch/utils/algebra_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch

import numpy as np
from scipy.special import factorial2 as f2

def btrace(M):
"""Computes the trace of batched matrices
Expand Down Expand Up @@ -43,6 +44,19 @@ def bdet2(M):
return M[..., 0, 0] * M[..., 1, 1] - M[..., 0, 1] * M[..., 1, 0]


def double_factorial(input):
"""Computes the double factorial of an array of int

Args:
input (List): input numbers

Returns:
List: values of the double factorial
"""
output = f2(input)
return np.array([1 if o==0 else o for o in output])


class BatchDeterminant(torch.autograd.Function):

@staticmethod
Expand Down
8 changes: 4 additions & 4 deletions qmctorch/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

def set_torch_double_precision():
"""Set the default precision to double for all torch tensors."""
torch.set_default_dtype = torch.float64
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.float64)
# torch.set_default_tensor_type(torch.DoubleTensor)


def set_torch_single_precision():
"""Set the default precision to single for all torch tensors."""
torch.set_default_dtype = torch.float32
torch.set_default_tensor_type(torch.FloatTensor)
torch.set_default_dtype(torch.float32)
# torch.set_default_tensor_type(torch.FloatTensor)


def fast_power(x, k, mask0=None, mask2=None):
Expand Down
28 changes: 10 additions & 18 deletions qmctorch/wavefunction/orbitals/norm_orbital.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import numpy as np

from ...utils.algebra_utils import double_factorial

def atomic_orbital_norm(basis):
"""Computes the norm of the atomic orbitals
Expand Down Expand Up @@ -77,16 +77,13 @@ def norm_gaussian_spherical(bas_n, bas_exp):
Returns:
torch.tensor: normalization factor
"""

from scipy.special import factorial2 as f2

bas_n = torch.tensor(bas_n)
bas_n = bas_n + 1.
exp1 = 0.25 * (2. * bas_n + 1.)

A = torch.tensor(bas_exp)**exp1
B = 2**(2. * bas_n + 3. / 2)
C = torch.as_tensor(f2(2 * bas_n.int() - 1) * np.pi **
C = torch.as_tensor(double_factorial(2 * bas_n.int() - 1) * np.pi **
0.5).type(torch.get_default_dtype())

return torch.sqrt(B / C) * A
Expand All @@ -106,22 +103,20 @@ def norm_slater_cartesian(a, b, c, n, exp):
Returns:
torch.tensor: normalization factor
"""
from scipy.special import factorial2 as f2

lvals = a + b + c + n + 1.

lfact = torch.as_tensor([np.math.factorial(int(2 * i))
for i in lvals]).type(torch.get_default_dtype())

prefact = 4 * np.pi * lfact / ((2 * exp)**(2 * lvals + 1))

num = torch.as_tensor(f2(2 * a.astype('int') - 1) *
f2(2 * b.astype('int') - 1) *
f2(2 * c.astype('int') - 1)
num = torch.as_tensor(double_factorial(2 * a.astype('int') - 1) *
double_factorial(2 * b.astype('int') - 1) *
double_factorial(2 * c.astype('int') - 1)
).type(torch.get_default_dtype())

denom = torch.as_tensor(
f2((2 * a + 2 * b + 2 * c + 1).astype('int')
double_factorial((2 * a + 2 * b + 2 * c + 1).astype('int')
)).type(torch.get_default_dtype())

return torch.sqrt(1. / (prefact * num / denom))
Expand All @@ -135,22 +130,19 @@ def norm_gaussian_cartesian(a, b, c, exp):
a (torch.tensor): exponent of x
b (torch.tensor): exponent of y
c (torch.tensor): exponent of z
exp (torch.tensor): Sater exponent
exp (torch.tensor): Slater exponent

Returns:
torch.tensor: normalization factor
"""

from scipy.special import factorial2 as f2

pref = torch.as_tensor((2 * exp / np.pi)**(0.75))
am1 = (2 * a - 1).astype('int')
x = (4 * exp)**(a / 2) / torch.sqrt(torch.as_tensor(f2(am1)))
x = (4 * exp)**(a / 2) / torch.sqrt(torch.as_tensor(double_factorial(am1)))

bm1 = (2 * b - 1).astype('int')
y = (4 * exp)**(b / 2) / torch.sqrt(torch.as_tensor(f2(bm1)))
y = (4 * exp)**(b / 2) / torch.sqrt(torch.as_tensor(double_factorial(bm1)))

cm1 = (2 * c - 1).astype('int')
z = (4 * exp)**(c / 2) / torch.sqrt(torch.as_tensor(f2(cm1)))
z = (4 * exp)**(c / 2) / torch.sqrt(torch.as_tensor(double_factorial(cm1)))

return (pref * x * y * z).type(torch.get_default_dtype())
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch
from torch.autograd import Variable, grad, gradcheck
from torch.autograd import Variable, grad
from qmctorch.wavefunction.jastrows.distance import ElectronElectronDistance
import unittest
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@

import numpy as np
import torch
from torch.autograd import Variable, grad, gradcheck
from torch.autograd import Variable, grad

from qmctorch.wavefunction.jastrows.elec_elec.jastrow_factor_electron_electron import JastrowFactorElectronElectron
from qmctorch.wavefunction.jastrows.elec_elec.kernels.fully_connected_jastrow_kernel import FullyConnectedJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

from qmctorch.wavefunction.jastrows.elec_elec.jastrow_factor_electron_electron import JastrowFactorElectronElectron
from qmctorch.wavefunction.jastrows.elec_elec.kernels.fully_connected_jastrow_kernel import FullyConnectedJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
3 changes: 2 additions & 1 deletion tests/wavefunction/jastrows/elec_elec/test_pade_jastrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

from qmctorch.wavefunction.jastrows.elec_elec.jastrow_factor_electron_electron import JastrowFactorElectronElectron
from qmctorch.wavefunction.jastrows.elec_elec.kernels.pade_jastrow_kernel import PadeJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

from qmctorch.wavefunction.jastrows.elec_elec.jastrow_factor_electron_electron import JastrowFactorElectronElectron
from qmctorch.wavefunction.jastrows.elec_elec.kernels.pade_jastrow_polynomial_kernel import PadeJastrowPolynomialKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

from qmctorch.wavefunction.jastrows.elec_elec.jastrow_factor_electron_electron import JastrowFactorElectronElectron
from qmctorch.wavefunction.jastrows.elec_elec.kernels.pade_jastrow_kernel import PadeJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

from qmctorch.wavefunction.jastrows.elec_elec.jastrow_factor_electron_electron import JastrowFactorElectronElectron
from qmctorch.wavefunction.jastrows.elec_elec.kernels.pade_jastrow_polynomial_kernel import PadeJastrowPolynomialKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import numpy as np
import torch
from torch.autograd import Variable, grad, gradcheck
from torch.autograd import Variable, grad
from qmctorch.wavefunction.jastrows.elec_elec_nuclei.jastrow_factor_electron_electron_nuclei import JastrowFactorElectronElectronNuclei
from qmctorch.wavefunction.jastrows.elec_elec_nuclei.kernels.boys_handy_jastrow_kernel import BoysHandyJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import numpy as np
import torch
from torch.autograd import Variable, grad, gradcheck
from torch.autograd import Variable, grad
from qmctorch.wavefunction.jastrows.elec_elec_nuclei.jastrow_factor_electron_electron_nuclei import JastrowFactorElectronElectronNuclei
from qmctorch.wavefunction.jastrows.elec_elec_nuclei.kernels.fully_connected_jastrow_kernel import FullyConnectedJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from torch.autograd import Variable, grad, gradcheck
from qmctorch.wavefunction.jastrows.elec_nuclei.jastrow_factor_electron_nuclei import JastrowFactorElectronNuclei
from qmctorch.wavefunction.jastrows.elec_nuclei.kernels import FullyConnectedJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

from qmctorch.wavefunction.jastrows.elec_nuclei.jastrow_factor_electron_nuclei import JastrowFactorElectronNuclei
from qmctorch.wavefunction.jastrows.elec_nuclei.kernels.pade_jastrow_kernel import PadeJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
3 changes: 2 additions & 1 deletion tests/wavefunction/jastrows/test_combined_terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from qmctorch.wavefunction.jastrows.elec_elec.kernels import PadeJastrowKernel as PadeJastrowKernelElecElec
from qmctorch.wavefunction.jastrows.elec_nuclei.kernels import PadeJastrowKernel as PadeJastrowKernelElecNuc
from qmctorch.wavefunction.jastrows.elec_elec_nuclei.kernels import BoysHandyJastrowKernel, FullyConnectedJastrowKernel
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from qmctorch.scf import Molecule
from qmctorch.wavefunction.orbitals.backflow.kernels import BackFlowKernelBase
from qmctorch.wavefunction.jastrows.distance.electron_electron_distance import ElectronElectronDistance
torch.set_default_tensor_type(torch.DoubleTensor)
from qmctorch.utils import set_torch_double_precision
set_torch_double_precision()

torch.manual_seed(101)
np.random.seed(101)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from qmctorch.scf import Molecule
from qmctorch.wavefunction.orbitals.backflow.kernels import BackFlowKernelInverse
from qmctorch.wavefunction.jastrows.distance.electron_electron_distance import ElectronElectronDistance
torch.set_default_tensor_type(torch.DoubleTensor)
from qmctorch.utils import set_torch_double_precision
set_torch_double_precision()

torch.manual_seed(101)
np.random.seed(101)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from qmctorch.scf import Molecule
from qmctorch.wavefunction.orbitals.backflow.backflow_transformation import BackFlowTransformation
from qmctorch.wavefunction.orbitals.backflow.kernels import BackFlowKernelInverse
torch.set_default_tensor_type(torch.DoubleTensor)
from qmctorch.utils import set_torch_double_precision
set_torch_double_precision()

torch.manual_seed(101)
np.random.seed(101)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from qmctorch.scf import Molecule
from qmctorch.wavefunction.orbitals.backflow.orbital_dependent_backflow_transformation import OrbitalDependentBackFlowTransformation
from qmctorch.wavefunction.orbitals.backflow.kernels import BackFlowKernelInverse
torch.set_default_tensor_type(torch.DoubleTensor)
from qmctorch.utils import set_torch_double_precision
set_torch_double_precision()

torch.manual_seed(101)
np.random.seed(101)
Expand Down
3 changes: 2 additions & 1 deletion tests/wavefunction/orbitals/test_ao_derivatives_adf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

from qmctorch.scf import Molecule
from qmctorch.wavefunction import SlaterJastrow
from qmctorch.utils import set_torch_double_precision
from ...path_utils import PATH_TEST

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
3 changes: 2 additions & 1 deletion tests/wavefunction/orbitals/test_ao_derivatives_pyscf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

from qmctorch.scf import Molecule
from qmctorch.wavefunction import SlaterJastrow
from qmctorch.utils import set_torch_double_precision

torch.set_default_tensor_type(torch.DoubleTensor)
set_torch_double_precision()


def hess(out, pos):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from qmctorch.wavefunction import SlaterJastrow
from qmctorch.wavefunction.orbitals.atomic_orbitals_backflow import AtomicOrbitalsBackFlow
from qmctorch.wavefunction.orbitals.backflow.kernels import BackFlowKernelInverse
torch.set_default_tensor_type(torch.DoubleTensor)
from qmctorch.utils import set_torch_double_precision
set_torch_double_precision()

torch.manual_seed(101)
np.random.seed(101)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from qmctorch.wavefunction import SlaterJastrow
from qmctorch.wavefunction.orbitals.atomic_orbitals_orbital_dependent_backflow import AtomicOrbitalsOrbitalDependentBackFlow
from qmctorch.wavefunction.orbitals.backflow.kernels import BackFlowKernelInverse
torch.set_default_tensor_type(torch.DoubleTensor)
from qmctorch.utils import set_torch_double_precision
set_torch_double_precision()

torch.manual_seed(101)
np.random.seed(101)
Expand Down
Loading
Loading