
Commit

Merge pull request #2 from yuanchenyang/conditional
Added conditional training and sampling
yuanchenyang authored Oct 11, 2024
2 parents 7a80531 + 0065536 commit 9959bbe
Showing 6 changed files with 321 additions and 187 deletions.
42 changes: 42 additions & 0 deletions examples/fashion_mnist_dit_cond.py
@@ -0,0 +1,42 @@
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader
from torchvision import transforms as tf
from torchvision.datasets import FashionMNIST
from torchvision.utils import make_grid, save_image
from torch_ema import ExponentialMovingAverage as EMA
from tqdm import tqdm

from smalldiffusion import ScheduleDDPM, samples, training_loop, MappedDataset, DiT, CondEmbedderLabel

# Setup
accelerator = Accelerator()
dataset = FashionMNIST('datasets', train=True, download=True,
                       transform=tf.Compose([
                           tf.RandomHorizontalFlip(),
                           tf.ToTensor(),
                           tf.Lambda(lambda t: (t * 2) - 1)
                       ]))
loader = DataLoader(dataset, batch_size=1024, shuffle=True)
schedule = ScheduleDDPM(beta_start=0.0001, beta_end=0.02, N=1000)
cond_embed = CondEmbedderLabel(32*6, 10, 0.1)  # 32*6 = 192 matches head_dim*num_heads below; 10 class labels
model = DiT(in_dim=28, channels=1,
            patch_size=2, depth=6, head_dim=32, num_heads=6, mlp_ratio=4.0,
            cond_embed=cond_embed)

# Train
trainer = training_loop(loader, model, schedule, epochs=300, lr=1e-3, conditional=True,
                        accelerator=accelerator)
ema = EMA(model.parameters(), decay=0.99)
ema.to(accelerator.device)
for ns in trainer:
    ns.pbar.set_description(f'Loss={ns.loss.item():.5}')
    ema.update()

# Sample
with ema.average_parameters():
    *xt, x0 = samples(model, schedule.sample_sigmas(20), gam=1.6, batchsize=40,
                      cond=torch.tensor(list(range(10))*4),  # labels 0-9 repeated 4 times; samples() expects a tensor with a batch dimension
                      accelerator=accelerator)
    save_image(((make_grid(x0) + 1)/2).clamp(0, 1), 'fashion_mnist_samples.png')
    torch.save(model.state_dict(), 'checkpoint.pth')
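The script above samples each of the ten FashionMNIST classes four times across the 40-image batch, but leaves the new cfg_scale argument of samples (added to diffusion.py below) at its default of 0. A guided variant of the same call might look like the following sketch; the guidance strength 2.0 is an illustrative value and does not come from this commit.

# Sketch of a guided variant of the sampling call above; cfg_scale is the
# classifier-free guidance strength added in this commit, 2.0 is illustrative.
with ema.average_parameters():
    *xt, x0 = samples(model, schedule.sample_sigmas(20), gam=1.6, batchsize=40,
                      cond=torch.tensor(list(range(10))*4),
                      cfg_scale=2.0,
                      accelerator=accelerator)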
9 changes: 8 additions & 1 deletion src/smalldiffusion/__init__.py
@@ -8,5 +8,12 @@
 )
 
 from .model import (
-    TimeInputMLP, ModelMixin, get_sigma_embeds, IdealDenoiser, DiT
+    ModelMixin,
+    Scaled, PredX0, PredV,
+    TimeInputMLP, IdealDenoiser,
+    get_sigma_embeds,
+    SigmaEmbedderSinCos,
+    CondEmbedderLabel,
 )
+
+from .model_dit import DiT
48 changes: 28 additions & 20 deletions src/smalldiffusion/diffusion.py
@@ -8,7 +8,7 @@
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from types import SimpleNamespace
-from typing import Optional
+from typing import Optional, Union, Tuple
 
 class Schedule:
     '''Diffusion noise schedules parameterized by sigma'''
@@ -73,12 +73,15 @@ def __init__(self, N: int=1000, beta_start: float=0.0001, beta_end: float=0.02,
 # Given a batch of data x0, returns:
 #   eps  : i.i.d. normal with same shape as x0
 #   sigma: uniformly sampled from schedule, with shape Bx1x..x1 for broadcasting
-def generate_train_sample(x0: torch.FloatTensor, schedule: Schedule):
+def generate_train_sample(x0: Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]],
+                          schedule: Schedule, conditional: bool=False):
+    cond = x0[1] if conditional else None
+    x0 = x0[0] if conditional else x0
     sigma = schedule.sample_batch(x0)
     while len(sigma.shape) < len(x0.shape):
         sigma = sigma.unsqueeze(-1)
     eps = torch.randn_like(x0)
-    return sigma, eps
+    return x0, sigma, eps, cond
 
 # Model objects
 # Always called with (x, sigma):
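To make the new conditional path concrete: with conditional=True, each batch from the loader is expected to be an (images, labels) tuple, which generate_train_sample unpacks before broadcasting sigma. A minimal sketch, with shapes assuming the FashionMNIST example above (batches of 1024 single-channel 28x28 images):

batch = (torch.randn(1024, 1, 28, 28), torch.randint(0, 10, (1024,)))  # (images, labels)
x0, sigma, eps, cond = generate_train_sample(batch, schedule, conditional=True)
# x0:    (1024, 1, 28, 28)  images, unpacked from the tuple
# sigma: (1024, 1, 1, 1)    one noise level per sample, unsqueezed for broadcasting
# eps:   (1024, 1, 28, 28)  i.i.d. standard normal noise, same shape as x0
# cond:  (1024,)            integer class labels, later passed to get_loss(..., cond=cond)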
@@ -87,20 +90,22 @@ def generate_train_sample(x0: torch.FloatTensor, schedule: Schedule):
 # Otherwise, x[i] will be paired with sigma[i] when calling model
 # Have a `rand_input` method for generating random xt during sampling
 
-def training_loop(loader : DataLoader,
-                  model : nn.Module,
-                  schedule : Schedule,
-                  accelerator: Optional[Accelerator] = None,
-                  epochs : int = 10000,
-                  lr : float = 1e-3):
+def training_loop(loader : DataLoader,
+                  model : nn.Module,
+                  schedule : Schedule,
+                  accelerator : Optional[Accelerator] = None,
+                  epochs : int = 10000,
+                  lr : float = 1e-3,
+                  conditional : bool = False):
     accelerator = accelerator or Accelerator()
-    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
     model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
     for _ in (pbar := tqdm(range(epochs))):
         for x0 in loader:
+            model.train()
             optimizer.zero_grad()
-            sigma, eps = generate_train_sample(x0, schedule)
-            loss = model.get_loss(x0, sigma, eps)
+            x0, sigma, eps, cond = generate_train_sample(x0, schedule, conditional)
+            loss = model.get_loss(x0, sigma, eps, cond=cond)
             yield SimpleNamespace(**locals()) # For extracting training statistics
             accelerator.backward(loss)
             optimizer.step()
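Because conditional defaults to False and generate_train_sample then returns cond=None, the new signature is clearly meant to keep existing unconditional training scripts working (this assumes the model's get_loss accepts a cond keyword, which is defined in a file not shown in this diff). A minimal unconditional call, for contrast with the conditional example above:

# Unconditional training is unchanged: the loader yields plain image batches
# and get_loss receives cond=None.
trainer = training_loop(loader, model, schedule, epochs=300, lr=1e-3)
for ns in trainer:
    ns.pbar.set_description(f'Loss={ns.loss.item():.5}')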
@@ -114,19 +119,22 @@ def samples(model : nn.Module,
             sigmas : torch.FloatTensor, # Iterable with N+1 values for N sampling steps
             gam : float = 1.,           # Suggested to use gam >= 1
             mu : float = 0.,            # Requires mu in [0, 1)
+            cfg_scale : int = 0.,       # 0 means no classifier-free guidance
+            batchsize : int = 1,
             xt : Optional[torch.FloatTensor] = None,
-            accelerator: Optional[Accelerator] = None,
-            batchsize : int = 1):
+            cond : Optional[torch.Tensor] = None,
+            accelerator: Optional[Accelerator] = None):
     accelerator = accelerator or Accelerator()
-    if xt is None:
-        xt = model.rand_input(batchsize).to(accelerator.device) * sigmas[0]
-    else:
-        batchsize = xt.shape[0]
+    xt = model.rand_input(batchsize).to(accelerator.device) * sigmas[0] if xt is None else xt
+    if cond is not None:
+        assert cond.shape[0] == xt.shape[0], 'cond must have same shape as x!'
+        cond = cond.to(xt.device)
     eps = None
     for i, (sig, sig_prev) in enumerate(pairwise(sigmas)):
-        eps, eps_prev = model.predict_eps(xt, sig.to(xt)), eps
+        model.eval()
+        eps_prev, eps = eps, model.predict_eps_cfg(xt, sig.to(xt), cond, cfg_scale)
         eps_av = eps * gam + eps_prev * (1-gam) if i > 0 else eps
         sig_p = (sig_prev/sig**mu)**(1/(1-mu)) # sig_prev == sig**mu sig_p**(1-mu)
         eta = (sig_prev**2 - sig_p**2).sqrt()
-        xt = xt - (sig - sig_p) * eps_av + eta * model.rand_input(batchsize).to(xt)
+        xt = xt - (sig - sig_p) * eps_av + eta * model.rand_input(xt.shape[0]).to(xt)
         yield xt
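The sampling loop now calls model.predict_eps_cfg(xt, sigma, cond, cfg_scale) in place of model.predict_eps(xt, sigma); its body lives in a model file that is not part of this diff. Conventionally, classifier-free guidance blends a conditional and an unconditional noise prediction, and the "0 means no classifier-free guidance" comment in the signature suggests the usual rule sketched below. Treat this as a sketch of what the method presumably computes, not the repository's actual implementation; it also assumes predict_eps accepts an optional label argument.

def predict_eps_cfg_sketch(model, x, sigma, cond, cfg_scale):
    eps_cond   = model.predict_eps(x, sigma, cond=cond)   # label-conditioned prediction
    eps_uncond = model.predict_eps(x, sigma, cond=None)   # unconditional prediction
    # cfg_scale=0 reduces to the plain conditional prediction, matching the
    # "0 means no classifier-free guidance" comment in the signature above.
    return eps_cond + cfg_scale * (eps_cond - eps_uncond)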