From 49a6ed7f4e68f4f16c1755b9cf84a074fea761ae Mon Sep 17 00:00:00 2001 From: Nico Date: Thu, 19 Oct 2023 16:19:21 +0200 Subject: [PATCH] fix size issue --- docs/example/single_point/h2o_sampling.py | 17 +++---- qmctorch/scf/calculator/pyscf.py | 54 +++++++++++++++++------ 2 files changed, 50 insertions(+), 21 deletions(-) diff --git a/docs/example/single_point/h2o_sampling.py b/docs/example/single_point/h2o_sampling.py index 35b56d58..7903f4a5 100644 --- a/docs/example/single_point/h2o_sampling.py +++ b/docs/example/single_point/h2o_sampling.py @@ -6,7 +6,8 @@ # define the molecule mol = Molecule(atom='water.xyz', unit='angs', - calculator='pyscf', basis='sto-3g', name='water') + calculator='pyscf', basis='cc-pvdz', + name='water', redo_scf=True) # define the wave function wf = SlaterJastrow(mol, kinetic='jacobi', @@ -24,11 +25,11 @@ # single point obs = solver.single_point() -# reconfigure sampler -solver.sampler.ntherm = 0 -solver.sampler.ndecor = 5 +# # reconfigure sampler +# solver.sampler.ntherm = 0 +# solver.sampler.ndecor = 5 -# compute the sampling traj -pos = solver.sampler(solver.wf.pdf) -obs = solver.sampling_traj(pos) -plot_walkers_traj(obs.local_energy, walkers='mean') +# # compute the sampling traj +# pos = solver.sampler(solver.wf.pdf) +# obs = solver.sampling_traj(pos) +# plot_walkers_traj(obs.local_energy, walkers='mean') diff --git a/qmctorch/scf/calculator/pyscf.py b/qmctorch/scf/calculator/pyscf.py index 81d58f89..4863303f 100644 --- a/qmctorch/scf/calculator/pyscf.py +++ b/qmctorch/scf/calculator/pyscf.py @@ -4,7 +4,7 @@ from pyscf import gto, scf, dft import shutil from .calculator_base import CalculatorBase - +from ... import log class CalculatorPySCF(CalculatorBase): @@ -73,6 +73,7 @@ def get_basis_data(self, mol, rhf): bas_coeff, bas_exp = [], [] index_ctr = [] bas_n, bas_l = [], [] + bas_zeta = [] bas_kx, bas_ky, bas_kz = [], [], [] bas_n = [] bas_n_ori = self.get_bas_n(mol) @@ -80,30 +81,48 @@ def get_basis_data(self, mol, rhf): iao = 0 for ibas in range(mol.nbas): + # number of zeta function per bas + nzeta = mol.bas_nctr(ibas) + # number of contracted gaussian in that bas - # nctr = mol.bas_nctr(ibas) nctr = mol.bas_nprim(ibas) - # number of ao from that bas - mult = mol.bas_len_cart(ibas) + # number of ao from that bas <= ? + mult = mol.bas_len_cart(ibas) # quantum numbers n = bas_n_ori[ibas] lval = mol.bas_angular(ibas) - # get qn per bas - bas_n += [n] * nctr * mult - bas_l += [lval] * nctr * mult + # coeffs and exponents + coeffs = mol.bas_ctr_coeff(ibas) + exps = mol.bas_exp(ibas) + + # deal with the multiple zeta + if coeffs.shape != (nctr, nzeta): + raise ValueError('Contraction coefficients issue') + + if exps.shape != (nctr, nzeta): + exps = exps[:,np.newaxis] + exps = np.tile(exps,nzeta) + nctr *= nzeta # coeffs/exp - bas_coeff += mol.bas_ctr_coeff( - ibas).flatten().tolist() * mult - bas_exp += mol.bas_exp(ibas).flatten().tolist() * mult + bas_coeff += coeffs.flatten().tolist() * mult + bas_exp += exps.flatten().tolist() * mult + + + # get quantum numbers per bas + bas_n += [n] * nctr * mult + bas_l += [lval] * nctr * mult + + #record the zetas per bas + bas_zeta += [nzeta] * nctr * mult # number of shell per atoms nshells[mol.bas_atom(ibas)] += nctr * mult - for m in range(mult): + for _ in range(mult): index_ctr += [iao] * nctr iao += 1 @@ -117,8 +136,8 @@ def get_basis_data(self, mol, rhf): bas_kz += [k] * nctr bas_norm = [] - for expnt, lval in zip(bas_exp, bas_l): - bas_norm.append(mol.gto_norm(lval, expnt)) + for expnt, lval, zeta in zip(bas_exp, bas_l, bas_zeta): + bas_norm.append(mol.gto_norm(lval, expnt)/zeta) basis.nshells = nshells basis.index_ctr = index_ctr @@ -175,6 +194,8 @@ def get_atoms_str(self): @staticmethod def get_bas_n(mol): + recognized_labels = ['s','p','d'] + label2int = {'s': 1, 'p': 2, 'd': 3} labels = [l[:3] for l in mol.cart_labels(fmt=False)] unique_labels = [] @@ -182,5 +203,12 @@ def get_bas_n(mol): if l not in unique_labels: unique_labels.append(l) nlabel = [l[2][1] for l in unique_labels] + + if np.any([nl not in recognized_labels for nl in nlabel]): + log.error('QMCTORCH only implement the following orbitals: {0}', recognized_labels) + log.error('The following orbitals have been found: {0}', nlabel) + log.error('Using the basis set: {0}', mol.basis) + raise ValueError('Basis set not supported') + n = [label2int[nl] for nl in nlabel] return n