Skip to content

Commit

Permalink
gradient no backflow
Browse files Browse the repository at this point in the history
  • Loading branch information
NicoRenaud committed Nov 28, 2023
1 parent 9784ea1 commit 9ab7ae8
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 5 deletions.
7 changes: 6 additions & 1 deletion tests/wavefunction/base_test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,13 @@ class WaveFunctionBaseTest(unittest.TestCase):

def setUp(self):
"""Init the base test"""

def wf_placeholder(pos):
"""Callable for wf"""
return None

self.pos = None
self.wf = None
self.wf = wf_placeholder
self.nbatch = None

def test_forward(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ class ElecElecJastrowBaseTest(unittest.TestCase):

def setUp(self) -> None:
"""Init the test case"""
self.jastrow = None

def jastrow_callable(pos):
"""Empty callable for jastrow"""
return None

self.jastrow = jastrow_callable
self.nbatch = None
self.pos = None

Expand Down
6 changes: 5 additions & 1 deletion tests/wavefunction/orbitals/base_test_ao.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ class BaseTestAO:
class BaseTestAOderivatives(unittest.TestCase):

def setUp(self):
self.ao = None

def ao_callable(pos):
"""Callable for the AO"""
return None
self.ao = ao_callable
self.pos = None

def test_ao_deriv(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/wavefunction/test_slater_mgcn_graph_jastrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def test_kinetic_energy(self):

def test_gradients_wf(self):

grads = self.wf.gradients_jacobi(
grads = self.wf.gradients_jacobi_no_backflow(
self.pos, sum_grad=False).squeeze()
grad_auto = self.wf.gradients_autograd(self.pos)

Expand All @@ -181,7 +181,7 @@ def test_gradients_wf(self):

def test_gradients_pdf(self):

grads_pdf = self.wf.gradients_jacobi(self.pos, pdf=True)
grads_pdf = self.wf.gradients_jacobi_no_backflow(self.pos, pdf=True)
grads_auto = self.wf.gradients_autograd(self.pos, pdf=True)

assert torch.allclose(grads_pdf.sum(), grads_auto.sum())
Expand Down

0 comments on commit 9ab7ae8

Please sign in to comment.