Skip to content

Commit

Permalink
Simplified sampling code
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanchenyang committed Feb 7, 2024
1 parent 3c05481 commit b7abc02
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/smalldiffusion/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ def sample_sigmas(self, steps: int) -> torch.FloatTensor:
.round().astype(np.int64) - 1)
return self[indices + [0]]

def sample_batch(self, batchsize: int) -> torch.FloatTensor:
def sample_batch(self, x0: torch.FloatTensor) -> torch.FloatTensor:
'''Called during training to get a batch of randomly sampled sigma values
'''
return self[torch.randint(len(self), (batchsize,))]
batchsize = x0.shape[0]
return self[torch.randint(len(self), (batchsize,))].to(x0)

def sigmas_from_betas(betas: torch.FloatTensor):
return (1/torch.cumprod(1.0 - betas, dim=0) - 1).sqrt()
Expand All @@ -59,10 +60,10 @@ def __init__(self, N: int=1000, beta_start: float=0.00085, beta_end: float=0.012
# eps : i.i.d. normal with same shape as x0
# sigma: uniformly sampled from schedule, with shape Bx1x..x1 for broadcasting
def generate_train_sample(x0: torch.FloatTensor, schedule: Schedule):
sigma = schedule.sample_batch(x0.shape[0]).to(x0)
sigma = schedule.sample_batch(x0)
while len(sigma.shape) < len(x0.shape):
sigma = sigma.unsqueeze(-1)
eps = torch.randn(x0.shape).to(x0)
eps = torch.randn_like(x0)
return sigma, eps

# Model objects
Expand Down

0 comments on commit b7abc02

Please sign in to comment.