Skip to content

Commit

Permalink
Move eval out of sampling loop
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanchenyang committed Nov 19, 2024
1 parent c796f90 commit 2721d84
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/smalldiffusion/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,14 @@ def samples(model : nn.Module,
xt : Optional[torch.FloatTensor] = None,
cond : Optional[torch.Tensor] = None,
accelerator: Optional[Accelerator] = None):
model.eval()
accelerator = accelerator or Accelerator()
xt = model.rand_input(batchsize).to(accelerator.device) * sigmas[0] if xt is None else xt
if cond is not None:
assert cond.shape[0] == xt.shape[0], 'cond must have same shape as x!'
cond = cond.to(xt.device)
eps = None
for i, (sig, sig_prev) in enumerate(pairwise(sigmas)):
model.eval()
eps_prev, eps = eps, model.predict_eps_cfg(xt, sig.to(xt), cond, cfg_scale)
eps_av = eps * gam + eps_prev * (1-gam) if i > 0 else eps
sig_p = (sig_prev/sig**mu)**(1/(1-mu)) # sig_prev == sig**mu sig_p**(1-mu)
Expand Down

0 comments on commit 2721d84

Please sign in to comment.