From 4c7c101f17166b91ba6bd15e9dcd164d9a184594 Mon Sep 17 00:00:00 2001
From: Nicolas Renaud
Date: Thu, 9 Sep 2021 14:31:36 +0200
Subject: [PATCH] refac

---
 example/ho1d.py                               |  22 +-
 example/morse.py                              |  16 +-
 schrodinet/sampler/metropolis.py              |   3 +-
 .../solver/{plot_potential.py => plot.py}     |  20 +-
 schrodinet/solver/solver.py                   | 355 ++++++++++++++++++
 schrodinet/solver/solver_base.py              | 163 --------
 schrodinet/solver/solver_potential.py         | 199 ----------
 .../{wf_potential.py => wave_function_1d.py}  |   8 +-
 ...wf_potential_2d.py => wave_function_2d.py} |   4 +-
 tests/test_ho1d.py                            |  14 +-
 10 files changed, 403 insertions(+), 401 deletions(-)
 rename schrodinet/solver/{plot_potential.py => plot.py} (96%)
 create mode 100644 schrodinet/solver/solver.py
 delete mode 100644 schrodinet/solver/solver_base.py
 delete mode 100644 schrodinet/solver/solver_potential.py
 rename schrodinet/wavefunction/{wf_potential.py => wave_function_1d.py} (94%)
 rename schrodinet/wavefunction/{wf_potential_2d.py => wave_function_2d.py} (96%)

diff --git a/example/ho1d.py b/example/ho1d.py
index 4f5d12d..9d326b1 100644
--- a/example/ho1d.py
+++ b/example/ho1d.py
@@ -2,9 +2,9 @@
 from torch import optim
 
 from schrodinet.sampler.metropolis import Metropolis
-from schrodinet.wavefunction.wf_potential import Potential
-from schrodinet.solver.solver_potential import SolverPotential
-from schrodinet.solver.plot_potential import plot_results_1d, plotter1d
+from schrodinet.wavefunction.wave_function_1d import WaveFunction1D
+from schrodinet.solver.solver import Solver
+from schrodinet.solver.plot import plot_results_1d, plotter1d
 
 
 def pot_func(pos):
@@ -21,12 +21,11 @@ def ho1d_sol(pos):
 domain, ncenter = {'min': -5., 'max': 5.}, 11
 
 # wavefunction
-wf = Potential(pot_func, domain, ncenter, fcinit='random', nelec=1, sigma=1.)
+wf = WaveFunction1D(pot_func, domain, ncenter, sigma=1.)
 
 # sampler
 sampler = Metropolis(nwalkers=1000, nstep=200,
-                     step_size=1., nelec=wf.nelec,
-                     ndim=wf.ndim, init={'min': -5, 'max': 5})
+                     step_size=1., init=domain)
 
 # optimizer
 opt = optim.Adam(wf.parameters(), lr=0.05)
@@ -35,12 +34,13 @@ def ho1d_sol(pos):
 scheduler = optim.lr_scheduler.StepLR(opt, step_size=100, gamma=0.75)
 
 # define solver
-solver = SolverPotential(wf=wf, sampler=sampler,
-                         optimizer=opt, scheduler=scheduler)
+solver = Solver(wf=wf, sampler=sampler,
+                optimizer=opt, scheduler=scheduler)
 
 # train the wave function
-#plotter = plotter1d(wf, domain, 100, sol=ho1d_sol)  # , save='./image/')
-solver.run(300, loss='energy-manual', plot=None, save='model.pth')
+plotter = plotter1d(wf, domain, 100, sol=ho1d_sol)
+solver.run(300, loss='energy-manual', plot=plotter, save='model.pth')
 
 # plot the final wave function
-plot_results_1d(solver, domain, 100, ho1d_sol, e0=0.5, load='model.pth')
+plot_results_1d(solver, domain, 100, ho1d_sol,
+                e0=0.5, load='model.pth')
diff --git a/example/morse.py b/example/morse.py
index 233d965..78dc02c 100644
--- a/example/morse.py
+++ b/example/morse.py
@@ -2,9 +2,9 @@
 from torch import optim
 
 from schrodinet.sampler.metropolis import Metropolis
-from schrodinet.wavefunction.wf_potential import Potential
-from schrodinet.solver.solver_potential import SolverPotential
-from schrodinet.solver.plot_potential import plot_results_1d, plotter1d
+from schrodinet.wavefunction.wave_function_1d import WaveFunction1D
+from schrodinet.solver.solver import Solver
+from schrodinet.solver.plot import plot_results_1d
 
 
 def pot_func(pos):
@@ -22,7 +22,8 @@ def ho1d_sol(pos):
 domain, ncenter = {'min': -3., 'max': 8.}, 51
 
 # wavefunction
-wf = Potential(pot_func, domain, ncenter, fcinit='random', nelec=1, sigma=1)
+wf = WaveFunction1D(pot_func, domain, ncenter,
+                    fcinit='random', nelec=1, sigma=1)
 
 # sampler
 sampler = Metropolis(nwalkers=1000, nstep=500,
@@ -36,12 +37,13 @@ def ho1d_sol(pos):
 scheduler = optim.lr_scheduler.StepLR(opt, step_size=100, gamma=0.75)
 
 # define solver
-solver = SolverPotential(wf=wf, sampler=sampler,
-                         optimizer=opt, scheduler=scheduler)
+solver = Solver(wf=wf, sampler=sampler,
+                optimizer=opt, scheduler=scheduler)
 
 # train the wave function
 #plotter = plotter1d(wf, domain, 100, sol=ho1d_sol)
 solver.run(300, loss='variance', plot=None, save='model.pth')
 
 # plot the final wave function
-plot_results_1d(solver, domain, 100, ho1d_sol, e0=-0.125, load='model.pth')
+plot_results_1d(solver, domain, 100, ho1d_sol,
+                e0=-0.125, load='model.pth')
diff --git a/schrodinet/sampler/metropolis.py b/schrodinet/sampler/metropolis.py
index 5db9aec..0913de9 100644
--- a/schrodinet/sampler/metropolis.py
+++ b/schrodinet/sampler/metropolis.py
@@ -81,7 +81,8 @@ def generate(self, pdf, ntherm=10, ndecor=100, pos=None,
                     idecor += 1
 
         if with_tqdm:
-            print("Acceptance rate %1.3f %%" % (rate/self.nstep*100))
+            print("Acceptance rate %1.3f %%" %
+                  (rate/self.nstep*100))
 
         return torch.cat(pos)
 
diff --git a/schrodinet/solver/plot_potential.py b/schrodinet/solver/plot.py
similarity index 96%
rename from schrodinet/solver/plot_potential.py
rename to schrodinet/solver/plot.py
index 2e4ce30..d249792 100644
--- a/schrodinet/solver/plot_potential.py
+++ b/schrodinet/solver/plot.py
@@ -107,12 +107,15 @@ def drawNow(self):
         self.lwf.set_ydata(vp)
 
         if self.plot_weight:
-            self.pweight.set_xdata(self.wf.rbf.centers.detach().numpy())
-            self.pweight.set_ydata(self.wf.fc.weight.detach().numpy().T)
+            self.pweight.set_xdata(
+                self.wf.rbf.centers.detach().numpy())
+            self.pweight.set_ydata(
+                self.wf.fc.weight.detach().numpy().T)
 
         if self.plot_grad:
             if self.wf.fc.weight.requires_grad:
-                self.pgrad.set_xdata(self.wf.rbf.centers.detach().numpy())
+                self.pgrad.set_xdata(
+                    self.wf.rbf.centers.detach().numpy())
                 data = (self.wf.fc.weight.grad.detach().numpy().T)**2
                 data /= np.linalg.norm(data)
                 self.pgrad.set_ydata(data)
@@ -167,7 +170,7 @@ def plot_wf_1d(net, domain, res, grad=False, hist=False, pot=True, sol=None,
     vals = net.wf(X)
     vn = vals.detach().numpy().flatten()
     vn /= np.max(vn)
-    ax.plot(xn, vn, color='black', linewidth=2, label='DeepQMC')
+    ax.plot(xn, vn, color='black', linewidth=2, label='Schrodinet')
 
     if pot:
         pot = net.wf.nuclear_potential(X).detach().numpy()
@@ -213,7 +216,8 @@ def plot_results_1d(net, domain, res, sol=None, e0=None, load=None):
     ax0 = fig.add_subplot(211)
     ax1 = fig.add_subplot(212)
 
-    plot_wf_1d(net, domain, res, sol=sol, hist=False, ax=ax0, load=load)
+    plot_wf_1d(net, domain, res, sol=sol,
+               hist=False, ax=ax0, load=load)
     plot_observable(net.obs_dict, e0=e0, ax=ax1)
 
     plt.show()
@@ -291,7 +295,8 @@ def __init__(self, wf, domain, res, pot=False, kinetic=False, sol=None):
         self.yy = pos[:, 1].reshape(res[0], res[1])
 
         if callable(sol):
-            vs = sol(self.POS).view(self.res[0], self.res[1]).detach().numpy()
+            vs = sol(self.POS).view(
+                self.res[0], self.res[1]).detach().numpy()
             vs /= np.linalg.norm(vs)
             self.ax.plot_wireframe(self.xx, self.yy, vs,
                                    color='black', linewidth=1)
@@ -458,7 +463,8 @@ def plot_wf_3d(net, domain, res, sol=None,
     if hist:
         pos = net.sample().detach().numpy()
         for ielec in range(net.wf.nelec):
-            ax.scatter(pos[:, ielec*3], pos[:, ielec*3+1], pos[:, ielec*3+2])
+            ax.scatter(pos[:, ielec*3],
+                       pos[:, ielec*3+1], pos[:, ielec*3+2])
 
     if callable(sol):
 
diff --git a/schrodinet/solver/solver.py b/schrodinet/solver/solver.py
new file mode 100644
index 0000000..0c2b3ff
--- /dev/null
+++ b/schrodinet/solver/solver.py
@@ -0,0 +1,355 @@
+import numpy as np
+from types import SimpleNamespace
+import torch
+from torch.autograd import Variable
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from schrodinet.solver.torch_utils import DataSet, Loss, ZeroOneClipper
+
+
+class Solver(object):
+
+    def __init__(self, wf=None, sampler=None, optimizer=None, scheduler=None):
+
+        self.wf = wf
+        self.sampler = sampler
+        self.opt = optimizer
+        self.scheduler = scheduler
+        self.task = "wf_opt"
+
+        # resampling
+        self.resampling(ntherm=-1,
+                        resample=100,
+                        resample_from_last=True,
+                        resample_every=1)
+
+        # observable
+        self.observable(['local_energy'])
+
+    def run(self, nepoch, batchsize=None, save='model.pth',
+            loss='variance', plot=None, pos=None, with_tqdm=True):
+        '''Train the model.
+
+        Args:
+            nepoch : number of epochs
+            batchsize : size of the minibatch, if None take all points at once
+            pos : presampled electronic positions
+            obs_dict (dict, {name: []}) : quantities to be computed during
+                                    the training;
+                                    'name' must refer to a method of
+                                    the Solver instance
+            ntherm : thermalization of the MC sampling.
+                     If negative (-N) takes the last N entries
+            resample : number of MC steps during the resampling
+            resample_from_last (bool) : if True, use the previous positions
+                                        as the starting point of the resampling
+            resample_every (int) : number of epochs between resamplings
+            loss : loss used ('energy', 'variance', or a callable for
+                   supervised training)
+            plot : None or plotter instance from plot.py to
+                   interactively monitor the training
+        '''
+
+        # checkpoint file
+        self.save_model = save
+
+        # sample the wave function
+        pos = self.sample(
+            pos=pos, ntherm=self.resample.ntherm, with_tqdm=with_tqdm)
+
+        # determine the batching mode
+        if batchsize is None:
+            batchsize = len(pos)
+
+        # change the number of steps
+        _nstep_save = self.sampler.nstep
+        self.sampler.nstep = self.resample.resample
+
+        # create the data loader
+        self.dataset = DataSet(pos)
+        self.dataloader = DataLoader(
+            self.dataset, batch_size=batchsize)
+
+        # get the loss
+        self.loss = Loss(self.wf, method=loss)
+
+        # clipper for the fc weights
+        clipper = ZeroOneClipper()
+
+        cumulative_loss = []
+        min_loss = 1E3
+
+        for n in range(nepoch):
+            print('----------------------------------------')
+            print('epoch %d' % n)
+
+            cumulative_loss = 0
+            for ibatch, data in enumerate(self.dataloader):
+
+                lpos = Variable(data)
+                lpos.requires_grad = True
+
+                loss, eloc = self.evaluate_gradient(
+                    lpos, self.loss.method)
+                cumulative_loss += loss
+                self.opt.step()
+
+                if self.wf.fc.clip:
+                    self.wf.fc.apply(clipper)
+
+            if plot is not None:
+                plot.drawNow()
+
+            if cumulative_loss < min_loss:
+                min_loss = self.save_checkpoint(
+                    n, cumulative_loss, self.save_model)
+
+            # get the observables
+            self.get_observable(
+                self.obs_dict, pos, eloc, ibatch=ibatch)
+            self.print_observable(cumulative_loss)
+
+            print('----------------------------------------')
+
+            # resample the data
+            if (n % self.resample.resample_every == 0) or (n == nepoch-1):
+                if self.resample.resample_from_last:
+                    pos = pos.clone().detach()
+                else:
+                    pos = None
+                pos = self.sample(
+                    pos=pos, ntherm=self.resample.ntherm, with_tqdm=False)
+                self.dataloader.dataset.data = pos
+
+            if self.scheduler is not None:
+                self.scheduler.step()
+
+        # restore the sampler number of steps
+        self.sampler.nstep = _nstep_save
+
+    def evaluate_gradient(self, lpos, loss):
+        """Evaluate the gradient.
+
+        Arguments:
+            lpos {torch.tensor} -- positions of the walkers
+            loss {str} -- loss used to compute the gradient
+                          ('variance', 'energy' or 'energy-manual')
+
+        Returns:
+            tuple -- (loss, local energy)
+        """
+
+        if loss in ['variance', 'energy']:
+            loss, eloc = self._evaluate_grad_auto(lpos)
+
+        elif loss == 'energy-manual':
+            loss, eloc = self._evaluate_grad_manual(lpos)
+
+        else:
+            raise ValueError('Error in gradient method')
+
+        if torch.isnan(loss):
+            raise ValueError("Nans detected in the loss")
+
+        return loss, eloc
+
+    def _evaluate_grad_auto(self, lpos):
+        """Evaluate the gradient using automatic diff of the required loss.
+
+        Arguments:
+            lpos {torch.tensor} -- positions of the walkers
+
+        Returns:
+            tuple -- (loss, local energy)
+        """
+
+        # compute the loss
+        loss, eloc = self.loss(lpos)
+
+        # compute local gradients
+        self.opt.zero_grad()
+        loss.backward()
+
+        return loss, eloc
+
+    def _evaluate_grad_manual(self, lpos):
+        """Evaluate the gradient using a low-variance expression.
+
+        Arguments:
+            lpos {torch.tensor} -- positions of the walkers
+
+        Returns:
+            tuple -- (loss, local energy)
+        """
+
+        # gradient of the total energy:
+        # dE/dk = 2 < (dpsi/dk)/psi * (E_L - <E_L>) >
+
+        # compute local energy and wf values
+        psi = self.wf(lpos)
+        eloc = self.wf.local_energy(lpos, wf=psi)
+        norm = 1./len(psi)
+
+        # evaluate the prefactor of the grads
+        weight = eloc.clone()
+        weight -= torch.mean(eloc)
+        weight /= psi
+        weight *= 2.
+        weight *= norm
+
+        # compute the gradients
+        self.opt.zero_grad()
+        psi.backward(weight)
+
+        return torch.mean(eloc), eloc
+
+    def resampling(self, ntherm=-1, resample=100, resample_from_last=True,
+                   resample_every=1):
+        '''Configure the resampling options.'''
+        self.resample = SimpleNamespace()
+        self.resample.ntherm = ntherm
+        self.resample.resample = resample
+        self.resample.resample_from_last = resample_from_last
+        self.resample.resample_every = resample_every
+
+    def observable(self, obs):
+        '''Create the observables we want to track.'''
+
+        # reset the obs
+        self.obs_dict = {}
+
+        for k in obs:
+            self.obs_dict[k] = []
+
+        if 'local_energy' not in self.obs_dict:
+            self.obs_dict['local_energy'] = []
+
+        if self.task == 'geo_opt' and 'geometry' not in self.obs_dict:
+            self.obs_dict['geometry'] = []
+
+        for key, p in zip(self.wf.state_dict().keys(), self.wf.parameters()):
+            if p.requires_grad:
+                self.obs_dict[key] = []
+                self.obs_dict[key+'.grad'] = []
+
+    def sample(self, ntherm=-1, ndecor=100, with_tqdm=True, pos=None):
+        '''Sample the wave function.'''
+
+        pos = self.sampler.generate(
+            self.wf.pdf, ntherm=ntherm, ndecor=ndecor,
+            with_tqdm=with_tqdm, pos=pos)
+        pos.requires_grad = True
+        return pos
+
+    def get_observable(self, obs_dict, pos, eloc=None, ibatch=None, **kwargs):
+        '''Compute all the required observables.
+
+        Args:
+            obs_dict : a dictionary whose keys correspond
+                       to methods of self.wf
+            **kwargs : the possible arguments for the methods
+        TODO : match the signature of the callables
+        '''
+
+        for obs in self.obs_dict.keys():
+
+            if obs == 'local_energy' and eloc is not None:
+                data = eloc.cpu().detach().numpy()
+
+                if (ibatch is None) or (ibatch == 0):
+                    self.obs_dict[obs].append(data)
+                else:
+                    self.obs_dict[obs][-1] = np.append(
+                        self.obs_dict[obs][-1], data)
+
+            # store variational parameters
+            elif obs in self.wf.state_dict():
+                layer, param = obs.split('.')
+                p = self.wf.__getattr__(layer).__getattr__(param)
+                self.obs_dict[obs].append(p.data.clone().numpy())
+
+                if p.grad is not None:
+                    self.obs_dict[obs +
+                                  '.grad'].append(p.grad.clone().numpy())
+                else:
+                    self.obs_dict[obs +
+                                  '.grad'].append(torch.zeros_like(p.data))
+
+            # get the method
+            elif hasattr(self.wf, obs):
+                func = self.wf.__getattribute__(obs)
+                data = func(pos)
+                if isinstance(data, torch.Tensor):
+                    data = data.detach().numpy()
+                self.obs_dict[obs].append(data)
+
+    def print_observable(self, cumulative_loss, verbose=False):
+        """Print the observables to screen.
+
+        Arguments:
+            cumulative_loss {float} -- current loss value
+
+        Keyword Arguments:
+            verbose {bool} -- print all the observables (default: {False})
+        """
+
+        for k in self.obs_dict:
+
+            if k == 'local_energy':
+
+                eloc = self.obs_dict['local_energy'][-1]
+                e = np.mean(eloc)
+                v = np.var(eloc)
+                err = np.sqrt(v/len(eloc))
+                print('energy : %f +/- %f' % (e, err))
+                print('variance : %f' % np.sqrt(v))
+
+            elif verbose:
+                print(k + ' : ', self.obs_dict[k][-1])
+        print('loss %f' % (cumulative_loss))
+
+    def get_wf(self, x):
+        '''Get the value of the wave function at x.'''
+        vals = self.wf(x)
+        return vals.detach().numpy().flatten()
+
+    def energy(self, pos=None):
+        '''Get the energy of the wave function.'''
+        if pos is None:
+            pos = self.sample(ntherm=-1)
+        return self.wf.energy(pos)
+
+    def variance(self, pos=None):
+        '''Get the variance of the wave function.'''
+        if pos is None:
+            pos = self.sample(ntherm=-1)
+        return self.wf.variance(pos)
+
+    def single_point(self, pos=None, prt=True, ntherm=-1, ndecor=100):
+        '''Perform a single-point calculation.'''
+        if pos is None:
+            pos = self.sample(ntherm=ntherm, ndecor=ndecor)
+
+        e, s = self.wf._energy_variance(pos)
+        if prt:
+            print('Energy : ', e)
+            print('Variance : ', s)
+        return pos, e, s
+
+    def save_checkpoint(self, epoch, loss, filename):
+        torch.save({
+            'epoch': epoch,
+            'model_state_dict': self.wf.state_dict(),
+            'optimizer_state_dict': self.opt.state_dict(),
+            'loss': loss
+        }, filename)
+        return loss
+
+    def sampling_traj(self, pos):
+        ndim = pos.shape[-1]
+        p = pos.view(-1, self.sampler.nwalkers, ndim)
+        el = []
+        for ip in tqdm(p):
+            el.append(self.wf.local_energy(ip).detach().numpy())
+        return {'local_energy': el, 'pos': p}
diff --git a/schrodinet/solver/solver_base.py b/schrodinet/solver/solver_base.py
deleted file mode 100644
index 25547b2..0000000
--- a/schrodinet/solver/solver_base.py
+++ /dev/null
@@ -1,163 +0,0 @@
-import torch
-from types import SimpleNamespace
-from tqdm import tqdm
-import numpy as np
-
-class SolverBase(object):
-
-    def __init__(self, wf=None, sampler=None, optimizer=None):
-
-        self.wf = wf
-        self.sampler = sampler
-        self.opt = optimizer
-
-    def resampling(self, ntherm=-1, resample=100, resample_from_last=True,
-                   resample_every=1):
-        '''Configure the resampling options.'''
-        self.resample = SimpleNamespace()
-        self.resample.ntherm = ntherm
-        self.resample.resample = resample
-        self.resample.resample_from_last = resample_from_last
-        self.resample.resample_every = resample_every
-
-    def observable(self, obs):
-        '''Create the observalbe we want to track.'''
-
-        # reset the obs
-        self.obs_dict = {}
-
-        for k in obs:
-            self.obs_dict[k] = []
-
-        if 'local_energy' not in self.obs_dict:
-            self.obs_dict['local_energy'] = []
-
-        if self.task == 'geo_opt' and 'geometry' not in self.obs_dict:
-            self.obs_dict['geometry'] = []
-
-        for key, p in zip(self.wf.state_dict().keys(), self.wf.parameters()):
-            if p.requires_grad:
-                self.obs_dict[key] = []
-                self.obs_dict[key+'.grad'] = []
-
-    def sample(self, ntherm=-1, ndecor=100, with_tqdm=True, pos=None):
-        ''' sample the wave function.'''
-
-        pos = self.sampler.generate(
-            self.wf.pdf, ntherm=ntherm, ndecor=ndecor,
-            with_tqdm=with_tqdm, pos=pos)
-        pos.requires_grad = True
-        return pos
-
-    def get_observable(self, obs_dict, pos, eloc=None, ibatch=None, **kwargs):
-        '''compute all the required observable.
-
-        Args :
-            obs_dict : a dictionanry with all keys
-                       corresponding to a method of self.wf
-            **kwargs : the possible arguments for the methods
-        TODO : match the signature of the callables
-        '''
-
-        for obs in self. obs_dict.keys():
-
-            if obs == 'local_energy' and eloc is not None:
-                data = eloc.cpu().detach().numpy()
-
-                if (ibatch is None) or (ibatch == 0):
-                    self.obs_dict[obs].append(data)
-                else:
-                    self.obs_dict[obs][-1] = np.append(
-                        self.obs_dict[obs][-1], data)
-
-            # store variational parameter
-            elif obs in self.wf.state_dict():
-                layer, param = obs.split('.')
-                p = self.wf.__getattr__(layer).__getattr__(param)
-                self.obs_dict[obs].append(p.data.clone().numpy())
-
-                if p.grad is not None:
-                    self.obs_dict[obs+'.grad'].append(p.grad.clone().numpy())
-                else:
-                    self.obs_dict[obs+'.grad'].append(torch.zeros_like(p.data))
-
-            # get the method
-            elif hasattr(self.wf, obs):
-                func = self.wf.__getattribute__(obs)
-                data = func(pos)
-                if isinstance(data, torch.Tensor):
-                    data = data.detach().numpy()
-                self.obs_dict[obs].append(data)
-
-    def print_observable(self, cumulative_loss, verbose=False):
-        """Print the observalbe to csreen
-
-        Arguments:
-            cumulative_loss {float} -- current loss value
-
-        Keyword Arguments:
-            verbose {bool} -- print all the observables (default: {False})
-        """
-
-        for k in self.obs_dict:
-
-            if k == 'local_energy':
-
-                eloc = self.obs_dict['local_energy'][-1]
-                e = np.mean(eloc)
-                v = np.var(eloc)
-                err = np.sqrt(v/len(eloc))
-                print('energy : %f +/- %f' % (e, err))
-                print('variance : %f' % np.sqrt(v))
-
-            elif verbose:
-                print(k + ' : ', self.obs_dict[k][-1])
-        print('loss %f' % (cumulative_loss))
-
-    def get_wf(self, x):
-        '''Get the value of the wave functions at x.'''
-        vals = self.wf(x)
-        return vals.detach().numpy().flatten()
-
-    def energy(self, pos=None):
-        '''Get the energy of the wave function.'''
-        if pos is None:
-            pos = self.sample(ntherm=-1)
-        return self.wf.energy(pos)
-
-    def variance(self, pos):
-        '''Get the variance of the wave function.'''
-        if pos is None:
-            pos = self.sample(ntherm=-1)
-        return self.wf.variance(pos)
-
-    def single_point(self, pos=None, prt=True, ntherm=-1, ndecor=100):
-        '''Performs a single point calculation.'''
-        if pos is None:
-            pos = self.sample(ntherm=ntherm, ndecor=ndecor)
-
-        e, s = self.wf._energy_variance(pos)
-        if prt:
-            print('Energy : ', e)
-            print('Variance : ', s)
-        return pos, e, s
-
-    def save_checkpoint(self, epoch, loss, filename):
-        torch.save({
-            'epoch': epoch,
-            'model_state_dict': self.wf.state_dict(),
-            'optimzier_state_dict': self.opt.state_dict(),
-            'loss': loss
-        }, filename)
-        return loss
-
-    def sampling_traj(self, pos):
-        ndim = pos.shape[-1]
-        p = pos.view(-1, self.sampler.nwalkers, ndim)
-        el = []
-        for ip in tqdm(p):
-            el.append(self.wf.local_energy(ip).detach().numpy())
-        return {'local_energy': el, 'pos': p}
-
-    def run(self, nepoch, batchsize=None, loss='variance'):
-        raise NotImplementedError()
diff --git a/schrodinet/solver/solver_potential.py b/schrodinet/solver/solver_potential.py
deleted file mode 100644
index aceb3ac..0000000
--- a/schrodinet/solver/solver_potential.py
+++ /dev/null
@@ -1,199 +0,0 @@
-import numpy as np
-
-import torch
-from torch.autograd import Variable
-from torch.utils.data import DataLoader
-
-from schrodinet.solver.solver_base import SolverBase
-from schrodinet.solver.torch_utils import DataSet, Loss, ZeroOneClipper
-
-
-class SolverPotential(SolverBase):
-
-    def __init__(self, wf=None, sampler=None, optimizer=None,
-                 scheduler=None):
-        SolverBase.__init__(self, wf, sampler, optimizer)
-        self.scheduler = scheduler
-        self.task = "wf_opt"
-
-        # esampling
-        self.resampling(ntherm=-1,
-                        resample=100,
-                        resample_from_last=True,
-                        resample_every=1)
-
-        # observalbe
-        self.observable(['local_energy'])
-
-    def run(self, nepoch, batchsize=None, save='model.pth', loss='variance',
-            plot=None, pos = None, with_tqdm=True):
-        '''Train the model.
-
-        Arg:
-            nepoch : number of epoch
-            batchsize : size of the minibatch, if None take all points at once
-            pos : presampled electronic poition
-            obs_dict (dict, {name: []} ) : quantities to be computed during
-                                    the training
-                                    'name' must refer to a method of
-                                    the Solver instance
-            ntherm : thermalization of the MC sampling. If negative (-N) takes
-                     the last N entries
-            resample : number of MC step during the resampling
-            resample_from_last (bool) : if true use the previous position as
-                                        starting for the resampling
-            resample_every (int) : number of epch between resampling
-            loss : loss used ('energy','variance' or callable (for supervised)
-            plot : None or plotter instance from plot_utils.py to
-                   interactively monitor the training
-        '''
-
-        # checkpoint file
-        self.save_model = save
-
-        # sample the wave function
-        pos = self.sample(pos=pos, ntherm=self.resample.ntherm, with_tqdm=with_tqdm)
-
-        # determine the batching mode
-        if batchsize is None:
-            batchsize = len(pos)
-
-        # change the number of steps
-        _nstep_save = self.sampler.nstep
-        self.sampler.nstep = self.resample.resample
-
-        # create the data loader
-        self.dataset = DataSet(pos)
-        self.dataloader = DataLoader(self.dataset, batch_size=batchsize)
-
-        # get the loss
-        self.loss = Loss(self.wf, method=loss)
-
-        # clipper for the fc weights
-        clipper = ZeroOneClipper()
-
-        cumulative_loss = []
-        min_loss = 1E3
-
-        for n in range(nepoch):
-            print('----------------------------------------')
-            print('epoch %d' % n)
-
-            cumulative_loss = 0
-            for ibatch, data in enumerate(self.dataloader):
-
-                lpos = Variable(data)
-                lpos.requires_grad = True
-
-                loss, eloc = self.evaluate_gradient(lpos, self.loss.method)
-                cumulative_loss += loss
-                self.opt.step()
-
-                if self.wf.fc.clip:
-                    self.wf.fc.apply(clipper)
-
-            if plot is not None:
-                plot.drawNow()
-
-            if cumulative_loss < min_loss:
-                min_loss = self.save_checkpoint(
-                    n, cumulative_loss, self.save_model)
-
-            # get the observalbes
-            self.get_observable(self.obs_dict, pos, eloc, ibatch=ibatch)
-            self.print_observable(cumulative_loss)
-
-            print('----------------------------------------')
-
-            # resample the data
-            if (n % self.resample.resample_every == 0) or (n == nepoch-1):
-                if self.resample.resample_from_last:
-                    pos = pos.clone().detach()
-                else:
-                    pos = None
-                pos = self.sample(
-                    pos=pos, ntherm=self.resample.ntherm, with_tqdm=False)
-                self.dataloader.dataset.data = pos
-
-            if self.scheduler is not None:
-                self.scheduler.step()
-
-        # restore the sampler number of step
-        self.sampler.nstep = _nstep_save
-
-    def evaluate_gradient(self, lpos, loss):
-        """Evaluate the gradient
-
-        Arguments:
-            grad {str} -- method of the gradient (auto, manual)
-            lpos {torch.tensor} -- positions of the walkers
-
-        Returns:
-            tuple -- (loss, local energy)
-        """
-
-        if loss in ['variance','energy']:
-            loss, eloc = self._evaluate_grad_auto(lpos)
-
-        elif loss == 'energy-manual':
-            loss, eloc = self._evaluate_grad_manual(lpos)
-
-        else:
-            raise ValueError('Error in gradient method')
-
-        if torch.isnan(loss):
-            raise ValueError("Nans detected in the loss")
-
-        return loss, eloc
-
-    def _evaluate_grad_auto(self, lpos):
-        """Evaluate the gradient using automatic diff of the required loss.
-
-        Arguments:
-            lpos {torch.tensor} -- positions of the walkers
-
-        Returns:
-            tuple -- (loss, local energy)
-        """
-
-        # compute the loss
-        loss, eloc = self.loss(lpos)
-
-        # compute local gradients
-        self.opt.zero_grad()
-        loss.backward()
-
-        return loss, eloc
-
-    def _evaluate_grad_manual(self, lpos):
-        """Evaluate the gradient using a low variance method
-
-        Arguments:
-            lpos {torch.tensor} -- positions of the walkers
-
-        Returns:
-            tuple -- (loss, local energy)
-        """
-
-        ''' Get the gradient of the total energy
-        dE/dk = < (dpsi/dk)/psi (E_L - <E>) >
-        '''
-
-        # compute local energy and wf values
-        psi = self.wf(lpos)
-        eloc = self.wf.local_energy(lpos, wf=psi)
-        norm = 1./len(psi)
-
-        # evaluate the prefactor of the grads
-        weight = eloc.clone()
-        weight -= torch.mean(eloc)
-        weight /= psi
-        weight *= 2.
-        weight *= norm
-
-        # compute the gradients
-        self.opt.zero_grad()
-        psi.backward(weight)
-
-        return torch.mean(eloc), eloc
\ No newline at end of file
diff --git a/schrodinet/wavefunction/wf_potential.py b/schrodinet/wavefunction/wave_function_1d.py
similarity index 94%
rename from schrodinet/wavefunction/wf_potential.py
rename to schrodinet/wavefunction/wave_function_1d.py
index 5275ee9..0fdad01 100644
--- a/schrodinet/wavefunction/wf_potential.py
+++ b/schrodinet/wavefunction/wave_function_1d.py
@@ -5,11 +5,11 @@
 from schrodinet.wavefunction.rbf import RBF_Gaussian as RBF
 
 
-class Potential(WaveFunction):
+class WaveFunction1D(WaveFunction):
 
-    def __init__(self, fpot, domain, ncenter, nelec=1, ndim=1, fcinit=0.1,
+    def __init__(self, fpot, domain, ncenter, nelec=1, ndim=1, fcinit='random',
                  sigma=1.):
-        super(Potential, self).__init__(nelec, ndim)
+        super(WaveFunction1D, self).__init__(nelec, ndim)
 
         # get the RBF centers
         if not isinstance(ncenter, list):
@@ -35,7 +35,7 @@ def __init__(self, fpot, domain, ncenter, nelec=1, ndim=1, fcinit=0.1,
             self.fc.weight.data.fill_(fcinit)
 
         # book the potential function
-        self.user_potential = fpot 
+        self.user_potential = fpot
 
     def forward(self, x):
         ''' Compute the value of the wave function.
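Note on the estimator above: _evaluate_grad_manual (carried over unchanged into the
new Solver) implements the low-variance gradient
dE/dk = 2 < (dpsi/dk)/psi * (E_L - <E_L>) >. A minimal self-contained sketch of the
same weighting, assuming a toy one-parameter trial function psi_k(x) = exp(-k*x^2)
with its analytic local energy (all names here are illustrative, not part of the
patch):

    import torch

    k = torch.tensor(0.7, requires_grad=True)
    # stand-in walkers; the real solver draws them from |psi|^2 with Metropolis
    x = torch.randn(1000, 1)

    psi = torch.exp(-k * x**2)
    # analytic local energy of H = -1/2 d2/dx2 + x^2/2 for psi_k
    eloc = k + x**2 * (0.5 - 2.0 * k**2)

    # prefactor, as in _evaluate_grad_manual: 2/N * (E_L - <E_L>) / psi
    weight = 2.0 * (eloc - eloc.mean()) / psi / len(psi)

    # accumulates sum_i weight_i * dpsi_i/dk into k.grad
    psi.backward(weight.detach())
    print(float(k.grad))  # stochastic estimate of dE/dk

Centering E_L on its mean leaves the expectation unchanged but removes most of the
variance of the naive estimator, which is why the solver uses it for 'energy-manual'.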
diff --git a/schrodinet/wavefunction/wf_potential_2d.py b/schrodinet/wavefunction/wave_function_2d.py
similarity index 96%
rename from schrodinet/wavefunction/wf_potential_2d.py
rename to schrodinet/wavefunction/wave_function_2d.py
index 7d350e8..6933f16 100644
--- a/schrodinet/wavefunction/wf_potential_2d.py
+++ b/schrodinet/wavefunction/wave_function_2d.py
@@ -5,11 +5,11 @@
 from schrodinet.wavefunction.rbf import RBF_Gaussian as RBF
 
 
-class Potential2D(WaveFunction):
+class WaveFunction2D(WaveFunction):
 
     def __init__(self, fpot, domain, ncenter, nelec=1, ndim=1, fcinit=0.1,
                  sigma=1.):
-        super(Potential2D, self).__init__(nelec, ndim)
+        super(WaveFunction2D, self).__init__(nelec, ndim)
 
         # get the RBF centers
         if not isinstance(ncenter, list):
diff --git a/tests/test_ho1d.py b/tests/test_ho1d.py
index c5447df..2f7d91b 100644
--- a/tests/test_ho1d.py
+++ b/tests/test_ho1d.py
@@ -3,8 +3,8 @@
 from schrodinet.sampler.metropolis import Metropolis
 from schrodinet.sampler.hamiltonian import Hamiltonian
-from schrodinet.wavefunction.wf_potential import Potential
-from schrodinet.solver.solver_potential import SolverPotential
+from schrodinet.wavefunction.wave_function_1d import WaveFunction1D
+from schrodinet.solver.solver import Solver
 
 import numpy as np
 import unittest
 
@@ -26,8 +26,8 @@ def setUp(self):
 
         # wavefunction
         domain, ncenter = {'min': -5., 'max': 5.}, 11
-        self.wf = Potential(pot_func, domain, ncenter,
-                            fcinit='random', nelec=1, sigma=2.)
+        self.wf = WaveFunction1D(pot_func, domain, ncenter,
+                                 fcinit='random', nelec=1, sigma=2.)
 
         # sampler
         self.mh_sampler = Metropolis(nwalkers=1000, nstep=2000,
@@ -48,9 +48,9 @@ def setUp(self):
             self.opt, step_size=100, gamma=0.75)
 
         # network
-        self.solver = SolverPotential(wf=self.wf, sampler=self.mh_sampler,
-                                      optimizer=self.opt,
-                                      scheduler=self.scheduler)
+        self.solver = Solver(wf=self.wf, sampler=self.mh_sampler,
+                             optimizer=self.opt,
+                             scheduler=self.scheduler)
 
     def test_single_point_metropolis_hasting_sampling(self):
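For reference, a checkpoint written by Solver.save_checkpoint can be restored once
the model and optimizer are rebuilt. A sketch mirroring example/ho1d.py (the body of
pot_func is an assumption here, a harmonic well consistent with the e0=0.5 reference
energy used in that example; the dictionary keys follow save_checkpoint above):

    import torch
    from torch import optim
    from schrodinet.wavefunction.wave_function_1d import WaveFunction1D

    def pot_func(pos):
        return 0.5 * pos**2  # assumed harmonic well, as in example/ho1d.py

    wf = WaveFunction1D(pot_func, {'min': -5., 'max': 5.}, 11, sigma=1.)
    opt = optim.Adam(wf.parameters(), lr=0.05)

    # keys as written by Solver.save_checkpoint
    ckpt = torch.load('model.pth')
    wf.load_state_dict(ckpt['model_state_dict'])
    opt.load_state_dict(ckpt['optimizer_state_dict'])
    print('restored epoch %d' % ckpt['epoch'])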