From 69bb397ee7b9424102c8e4ed778ff499e1917b00 Mon Sep 17 00:00:00 2001 From: Nico Date: Thu, 5 Dec 2024 17:47:52 +0100 Subject: [PATCH] fix f2 issue --- qmctorch/utils/algebra_utils.py | 15 +++++++++- .../wavefunction/orbitals/norm_orbital.py | 28 +++++++------------ 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/qmctorch/utils/algebra_utils.py b/qmctorch/utils/algebra_utils.py index 50f5d5bf..726c5e53 100644 --- a/qmctorch/utils/algebra_utils.py +++ b/qmctorch/utils/algebra_utils.py @@ -1,5 +1,5 @@ import torch - +from scipy.special import factorial2 as f2 def btrace(M): """Computes the trace of batched matrices @@ -43,6 +43,19 @@ def bdet2(M): return M[..., 0, 0] * M[..., 1, 1] - M[..., 0, 1] * M[..., 1, 0] +def double_factorial(input): + """Computes the double factorial of an array of int + + Args: + input (List): input numbers + + Returns: + List: values of the double factorial + """ + output = f2(input) + return [1 if o==0 else o for o in output] + + class BatchDeterminant(torch.autograd.Function): @staticmethod diff --git a/qmctorch/wavefunction/orbitals/norm_orbital.py b/qmctorch/wavefunction/orbitals/norm_orbital.py index 77ee6990..6f3be4ff 100644 --- a/qmctorch/wavefunction/orbitals/norm_orbital.py +++ b/qmctorch/wavefunction/orbitals/norm_orbital.py @@ -1,6 +1,6 @@ import torch import numpy as np - +from ...utils.algebra_utils import double_factorial def atomic_orbital_norm(basis): """Computes the norm of the atomic orbitals @@ -77,16 +77,13 @@ def norm_gaussian_spherical(bas_n, bas_exp): Returns: torch.tensor: normalization factor """ - - from scipy.special import factorial2 as f2 - bas_n = torch.tensor(bas_n) bas_n = bas_n + 1. exp1 = 0.25 * (2. * bas_n + 1.) A = torch.tensor(bas_exp)**exp1 B = 2**(2. * bas_n + 3. / 2) - C = torch.as_tensor(f2(2 * bas_n.int() - 1) * np.pi ** + C = torch.as_tensor(double_factorial(2 * bas_n.int() - 1) * np.pi ** 0.5).type(torch.get_default_dtype()) return torch.sqrt(B / C) * A @@ -106,8 +103,6 @@ def norm_slater_cartesian(a, b, c, n, exp): Returns: torch.tensor: normalization factor """ - from scipy.special import factorial2 as f2 - lvals = a + b + c + n + 1. lfact = torch.as_tensor([np.math.factorial(int(2 * i)) @@ -115,13 +110,13 @@ def norm_slater_cartesian(a, b, c, n, exp): prefact = 4 * np.pi * lfact / ((2 * exp)**(2 * lvals + 1)) - num = torch.as_tensor(f2(2 * a.astype('int') - 1) * - f2(2 * b.astype('int') - 1) * - f2(2 * c.astype('int') - 1) + num = torch.as_tensor(double_factorial(2 * a.astype('int') - 1) * + double_factorial(2 * b.astype('int') - 1) * + double_factorial(2 * c.astype('int') - 1) ).type(torch.get_default_dtype()) denom = torch.as_tensor( - f2((2 * a + 2 * b + 2 * c + 1).astype('int') + double_factorial((2 * a + 2 * b + 2 * c + 1).astype('int') )).type(torch.get_default_dtype()) return torch.sqrt(1. / (prefact * num / denom)) @@ -135,22 +130,19 @@ def norm_gaussian_cartesian(a, b, c, exp): a (torch.tensor): exponent of x b (torch.tensor): exponent of y c (torch.tensor): exponent of z - exp (torch.tensor): Sater exponent + exp (torch.tensor): Slater exponent Returns: torch.tensor: normalization factor """ - - from scipy.special import factorial2 as f2 - pref = torch.as_tensor((2 * exp / np.pi)**(0.75)) am1 = (2 * a - 1).astype('int') - x = (4 * exp)**(a / 2) / torch.sqrt(torch.as_tensor(f2(am1))) + x = (4 * exp)**(a / 2) / torch.sqrt(torch.as_tensor(double_factorial(am1))) bm1 = (2 * b - 1).astype('int') - y = (4 * exp)**(b / 2) / torch.sqrt(torch.as_tensor(f2(bm1))) + y = (4 * exp)**(b / 2) / torch.sqrt(torch.as_tensor(double_factorial(bm1))) cm1 = (2 * c - 1).astype('int') - z = (4 * exp)**(c / 2) / torch.sqrt(torch.as_tensor(f2(cm1))) + z = (4 * exp)**(c / 2) / torch.sqrt(torch.as_tensor(double_factorial(cm1))) return (pref * x * y * z).type(torch.get_default_dtype())