Skip to content

Commit

Permalink
check for jastow before resetting params
Browse files Browse the repository at this point in the history
  • Loading branch information
NicoRenaud committed Oct 25, 2023
1 parent 49a6ed7 commit 9938815
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 48 deletions.
4 changes: 2 additions & 2 deletions docs/example/single_point/h2o_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@

# define the molecule
mol = Molecule(atom='water.xyz', unit='angs',
calculator='pyscf', basis='cc-pvdz',
calculator='pyscf', basis='sto-3g' ,
name='water', redo_scf=True)

# define the wave function
wf = SlaterJastrow(mol, kinetic='jacobi',
configs='ground_state')

# sampler
sampler = Metropolis(nwalkers=100, nstep=500, step_size=0.25,
sampler = Metropolis(nwalkers=1000, nstep=500, step_size=0.25,
nelec=wf.nelec, ndim=wf.ndim,
init=mol.domain('atomic'),
move={'type': 'all-elec', 'proba': 'normal'})
Expand Down
96 changes: 58 additions & 38 deletions qmctorch/scf/calculator/pyscf.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,21 @@ def get_basis_data(self, mol, rhf):
rhf {pyscf.scf} -- scf object
"""

# sphereical quantum nummbers
mvalues = {0: [0], 1: [-1,0,1], 2: [-2,-1,0,1,2]}

# cartesian quantum numbers
kx = {0: [0], 1: [1, 0, 0], 2: [2, 1, 1, 0, 0, 0]}
ky = {0: [0], 1: [0, 1, 0], 2: [0, 1, 0, 2, 1, 0]}
kz = {0: [0], 1: [0, 0, 1], 2: [0, 0, 1, 0, 1, 2]}

basis = SimpleNamespace()
basis.TotalEnergy = rhf.e_tot
basis.radial_type = 'gto_pure'
basis.harmonics_type = 'cart'
if self.basis_name.startswith('cc-'):
basis.harmonics_type = 'cart'
else:
basis.harmonics_type = 'cart'

# number of AO / MO
# can be different if d or f orbs are present
Expand All @@ -72,23 +79,23 @@ def get_basis_data(self, mol, rhf):
# init bas properties
bas_coeff, bas_exp = [], []
index_ctr = []
bas_n, bas_l = [], []
bas_n, bas_m, bas_l = [], [], []
bas_zeta = []
bas_kx, bas_ky, bas_kz = [], [], []
bas_n = []
bas_n_ori = self.get_bas_n(mol)

iao = 0
ishell = 0
for ibas in range(mol.nbas):

# number of zeta function per bas
nzeta = mol.bas_nctr(ibas)
# number of contracted gto per shell
nctr = mol.bas_nctr(ibas)

# number of contracted gaussian in that bas
nctr = mol.bas_nprim(ibas)
# number of primitive gaussian in that shell
nprim = mol.bas_nprim(ibas)

# number of ao from that bas <= ?
mult = mol.bas_len_cart(ibas)
# number of cartesian component of that bas ?
ncart_comp = mol.bas_len_cart(ibas)

# quantum numbers
n = bas_n_ori[ibas]
Expand All @@ -98,46 +105,58 @@ def get_basis_data(self, mol, rhf):
coeffs = mol.bas_ctr_coeff(ibas)
exps = mol.bas_exp(ibas)

# deal with the multiple zeta
if coeffs.shape != (nctr, nzeta):
# deal with multiple zeta
if coeffs.shape != (nprim, nctr):
raise ValueError('Contraction coefficients issue')

if exps.shape != (nctr, nzeta):
exps = exps[:,np.newaxis]
exps = np.tile(exps,nzeta)
nctr *= nzeta

# coeffs/exp
bas_coeff += coeffs.flatten().tolist() * mult
bas_exp += exps.flatten().tolist() * mult
# if nctr > 1:
# coeffs /= np.array(range(1,nctr+1))
# exps = exps[:,np.newaxis]
# exps = np.tile(exps,nctr)
# nprim *= nctr

ictr = 0
while ictr < nctr:

n = bas_n_ori[ishell]
coeffs_ictr = coeffs[:,ictr] / (ictr+1)

# coeffs/exp
bas_coeff += coeffs_ictr.flatten().tolist() * ncart_comp
bas_exp += exps.flatten().tolist() * ncart_comp

# get quantum numbers per bas
bas_n += [n] * nprim * ncart_comp
bas_l += [lval] * nprim * ncart_comp

# record the zetas per bas
bas_zeta += [nctr] * nprim * ncart_comp

# get quantum numbers per bas
bas_n += [n] * nctr * mult
bas_l += [lval] * nctr * mult
# number of shell per atoms
nshells[mol.bas_atom(ibas)] += nprim * ncart_comp

#record the zetas per bas
bas_zeta += [nzeta] * nctr * mult
for _ in range(ncart_comp):
index_ctr += [iao] * nprim
iao += 1

# number of shell per atoms
nshells[mol.bas_atom(ibas)] += nctr * mult
for m in mvalues[lval]:
bas_m += [m] * nprim

for _ in range(mult):
index_ctr += [iao] * nctr
iao += 1
for k in kx[lval]:
bas_kx += [k] * nprim

for k in kx[lval]:
bas_kx += [k] * nctr
for k in ky[lval]:
bas_ky += [k] * nprim

for k in ky[lval]:
bas_ky += [k] * nctr
for k in kz[lval]:
bas_kz += [k] * nprim

for k in kz[lval]:
bas_kz += [k] * nctr
ictr += 1
ishell += 1

bas_norm = []
for expnt, lval, zeta in zip(bas_exp, bas_l, bas_zeta):
bas_norm.append(mol.gto_norm(lval, expnt)/zeta)
for expnt, lval in zip(bas_exp, bas_l):
bas_norm.append(mol.gto_norm(lval, expnt))

basis.nshells = nshells
basis.index_ctr = index_ctr
Expand All @@ -160,6 +179,7 @@ def get_basis_data(self, mol, rhf):

basis.bas_n = bas_n
basis.bas_l = bas_l
basis.bas_m = bas_m

# the cartesian gto are all :
# x^a y^b z^c exp(-zeta r)
Expand Down Expand Up @@ -205,7 +225,7 @@ def get_bas_n(mol):
nlabel = [l[2][1] for l in unique_labels]

if np.any([nl not in recognized_labels for nl in nlabel]):
log.error('QMCTORCH only implement the following orbitals: {0}', recognized_labels)
log.error('the pyscf calculator only supports the following orbitals: {0}', recognized_labels)
log.error('The following orbitals have been found: {0}', nlabel)
log.error('Using the basis set: {0}', mol.basis)
raise ValueError('Basis set not supported')
Expand Down
5 changes: 3 additions & 2 deletions qmctorch/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ def set_params_requires_grad(self, wf_params=True, geo_params=False):

self.wf.fc.weight.requires_grad = wf_params

for param in self.wf.jastrow.parameters():
param.requires_grad = wf_params
if hasattr(self.wf, 'jastrow'):
for param in self.wf.jastrow.parameters():
param.requires_grad = wf_params

# no opt the atom positions
self.wf.ao.atom_coords.requires_grad = geo_params
Expand Down
3 changes: 2 additions & 1 deletion qmctorch/wavefunction/orbitals/norm_orbital.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ def norm_gaussian_spherical(bas_n, bas_exp):

from scipy.special import factorial2 as f2

bas_n = torch.tensor(bas_n)
bas_n = bas_n + 1.
exp1 = 0.25 * (2. * bas_n + 1.)

A = bas_exp**exp1
A = torch.tensor(bas_exp)**exp1
B = 2**(2. * bas_n + 3. / 2)
C = torch.as_tensor(f2(2 * bas_n.int() - 1) * np.pi **
0.5).type(torch.get_default_dtype())
Expand Down
16 changes: 11 additions & 5 deletions qmctorch/wavefunction/orbitals/spherical_harmonics.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,20 @@ def SphericalHarmonics(xyz, l, m, derivative=0, sum_grad=True, sum_hess=True):
if not sum_hess:
raise NotImplementedError(
'SphericalHarmonics cannot return individual component of the laplacian')
if derivative > 2:
raise NotImplementedError(
"Spherical Harmonics only accpet derivative=0,1,2 (%d found)" % derivative)

if not isinstance(derivative, list):
derivative = [derivative]


if sum_grad:
return get_spherical_harmonics(xyz, l, m, derivative)
output = [get_spherical_harmonics(xyz, l, m, d) for d in derivative]
if len(derivative) == 1:
return output[0]
else:
return output

else:
if derivative != 1:
if derivative != [1]:
raise ValueError(
'Gradient of the spherical harmonics require derivative=1')
return get_grad_spherical_harmonics(xyz, l, m)
Expand Down

0 comments on commit 9938815

Please sign in to comment.