Skip to content

Commit

Permalink
fix size issue
Browse files Browse the repository at this point in the history
  • Loading branch information
NicoRenaud committed Oct 19, 2023
1 parent 40d136f commit 49a6ed7
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 21 deletions.
17 changes: 9 additions & 8 deletions docs/example/single_point/h2o_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

# define the molecule
mol = Molecule(atom='water.xyz', unit='angs',
calculator='pyscf', basis='sto-3g', name='water')
calculator='pyscf', basis='cc-pvdz',
name='water', redo_scf=True)

# define the wave function
wf = SlaterJastrow(mol, kinetic='jacobi',
Expand All @@ -24,11 +25,11 @@
# single point
obs = solver.single_point()

# reconfigure sampler
solver.sampler.ntherm = 0
solver.sampler.ndecor = 5
# # reconfigure sampler
# solver.sampler.ntherm = 0
# solver.sampler.ndecor = 5

# compute the sampling traj
pos = solver.sampler(solver.wf.pdf)
obs = solver.sampling_traj(pos)
plot_walkers_traj(obs.local_energy, walkers='mean')
# # compute the sampling traj
# pos = solver.sampler(solver.wf.pdf)
# obs = solver.sampling_traj(pos)
# plot_walkers_traj(obs.local_energy, walkers='mean')
54 changes: 41 additions & 13 deletions qmctorch/scf/calculator/pyscf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pyscf import gto, scf, dft
import shutil
from .calculator_base import CalculatorBase

from ... import log

class CalculatorPySCF(CalculatorBase):

Expand Down Expand Up @@ -73,37 +73,56 @@ def get_basis_data(self, mol, rhf):
bas_coeff, bas_exp = [], []
index_ctr = []
bas_n, bas_l = [], []
bas_zeta = []
bas_kx, bas_ky, bas_kz = [], [], []
bas_n = []
bas_n_ori = self.get_bas_n(mol)

iao = 0
for ibas in range(mol.nbas):

# number of zeta function per bas
nzeta = mol.bas_nctr(ibas)

# number of contracted gaussian in that bas
# nctr = mol.bas_nctr(ibas)
nctr = mol.bas_nprim(ibas)

# number of ao from that bas
mult = mol.bas_len_cart(ibas)
# number of ao from that bas <= ?
mult = mol.bas_len_cart(ibas)

# quantum numbers
n = bas_n_ori[ibas]
lval = mol.bas_angular(ibas)

# get qn per bas
bas_n += [n] * nctr * mult
bas_l += [lval] * nctr * mult
# coeffs and exponents
coeffs = mol.bas_ctr_coeff(ibas)
exps = mol.bas_exp(ibas)

# deal with the multiple zeta
if coeffs.shape != (nctr, nzeta):
raise ValueError('Contraction coefficients issue')

if exps.shape != (nctr, nzeta):
exps = exps[:,np.newaxis]
exps = np.tile(exps,nzeta)
nctr *= nzeta

# coeffs/exp
bas_coeff += mol.bas_ctr_coeff(
ibas).flatten().tolist() * mult
bas_exp += mol.bas_exp(ibas).flatten().tolist() * mult
bas_coeff += coeffs.flatten().tolist() * mult
bas_exp += exps.flatten().tolist() * mult


# get quantum numbers per bas
bas_n += [n] * nctr * mult
bas_l += [lval] * nctr * mult

#record the zetas per bas
bas_zeta += [nzeta] * nctr * mult

# number of shell per atoms
nshells[mol.bas_atom(ibas)] += nctr * mult

for m in range(mult):
for _ in range(mult):
index_ctr += [iao] * nctr
iao += 1

Expand All @@ -117,8 +136,8 @@ def get_basis_data(self, mol, rhf):
bas_kz += [k] * nctr

bas_norm = []
for expnt, lval in zip(bas_exp, bas_l):
bas_norm.append(mol.gto_norm(lval, expnt))
for expnt, lval, zeta in zip(bas_exp, bas_l, bas_zeta):
bas_norm.append(mol.gto_norm(lval, expnt)/zeta)

basis.nshells = nshells
basis.index_ctr = index_ctr
Expand Down Expand Up @@ -175,12 +194,21 @@ def get_atoms_str(self):
@staticmethod
def get_bas_n(mol):

recognized_labels = ['s','p','d']

label2int = {'s': 1, 'p': 2, 'd': 3}
labels = [l[:3] for l in mol.cart_labels(fmt=False)]
unique_labels = []
for l in labels:
if l not in unique_labels:
unique_labels.append(l)
nlabel = [l[2][1] for l in unique_labels]

if np.any([nl not in recognized_labels for nl in nlabel]):
log.error('QMCTORCH only implement the following orbitals: {0}', recognized_labels)
log.error('The following orbitals have been found: {0}', nlabel)
log.error('Using the basis set: {0}', mol.basis)
raise ValueError('Basis set not supported')

n = [label2int[nl] for nl in nlabel]
return n

0 comments on commit 49a6ed7

Please sign in to comment.