diff --git a/qmctorch/wavefunction/slater_jastrow.py b/qmctorch/wavefunction/slater_jastrow.py index 6f919e2f..edde7ae7 100644 --- a/qmctorch/wavefunction/slater_jastrow.py +++ b/qmctorch/wavefunction/slater_jastrow.py @@ -194,7 +194,9 @@ def init_jastrow(self, jastrow): # create a simple Pade Jastrow factor as default if jastrow == 'default': - self.jastrow = JastrowFactorElectronElectron(self.mol, PadeJastrowKernel) + self.jastrow = JastrowFactorElectronElectron(self.mol, + PadeJastrowKernel, + cuda=self.cuda) elif isinstance(jastrow, list): self.jastrow = CombineJastrow(jastrow)