From b7abc02e35ec28b767fdcec2f764193bc68ca217 Mon Sep 17 00:00:00 2001 From: Chenyang Yuan Date: Wed, 7 Feb 2024 00:05:33 -0500 Subject: [PATCH] Simplified sampling code --- src/smalldiffusion/diffusion.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/smalldiffusion/diffusion.py b/src/smalldiffusion/diffusion.py index 5074d2b..776c0a5 100644 --- a/src/smalldiffusion/diffusion.py +++ b/src/smalldiffusion/diffusion.py @@ -32,10 +32,11 @@ def sample_sigmas(self, steps: int) -> torch.FloatTensor: .round().astype(np.int64) - 1) return self[indices + [0]] - def sample_batch(self, batchsize: int) -> torch.FloatTensor: + def sample_batch(self, x0: torch.FloatTensor) -> torch.FloatTensor: '''Called during training to get a batch of randomly sampled sigma values ''' - return self[torch.randint(len(self), (batchsize,))] + batchsize = x0.shape[0] + return self[torch.randint(len(self), (batchsize,))].to(x0) def sigmas_from_betas(betas: torch.FloatTensor): return (1/torch.cumprod(1.0 - betas, dim=0) - 1).sqrt() @@ -59,10 +60,10 @@ def __init__(self, N: int=1000, beta_start: float=0.00085, beta_end: float=0.012 # eps : i.i.d. normal with same shape as x0 # sigma: uniformly sampled from schedule, with shape Bx1x..x1 for broadcasting def generate_train_sample(x0: torch.FloatTensor, schedule: Schedule): - sigma = schedule.sample_batch(x0.shape[0]).to(x0) + sigma = schedule.sample_batch(x0) while len(sigma.shape) < len(x0.shape): sigma = sigma.unsqueeze(-1) - eps = torch.randn(x0.shape).to(x0) + eps = torch.randn_like(x0) return sigma, eps # Model objects