Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added conditional training and sampling #2

Merged
merged 6 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions examples/fashion_mnist_dit_cond.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader
from torchvision import transforms as tf
from torchvision.datasets import FashionMNIST
from torchvision.utils import make_grid, save_image
from torch_ema import ExponentialMovingAverage as EMA
from tqdm import tqdm

from smalldiffusion import ScheduleDDPM, samples, training_loop, MappedDataset, DiT, CondEmbedderLabel

# Setup
accelerator = Accelerator()
dataset = FashionMNIST('datasets', train=True, download=True,
transform=tf.Compose([
tf.RandomHorizontalFlip(),
tf.ToTensor(),
tf.Lambda(lambda t: (t * 2) - 1)
]))
loader = DataLoader(dataset, batch_size=1024, shuffle=True)
schedule = ScheduleDDPM(beta_start=0.0001, beta_end=0.02, N=1000)
cond_embed = CondEmbedderLabel(32*6, 10, 0.1)
model = DiT(in_dim=28, channels=1,
patch_size=2, depth=6, head_dim=32, num_heads=6, mlp_ratio=4.0,
cond_embed=cond_embed)

# Train
trainer = training_loop(loader, model, schedule, epochs=300, lr=1e-3, conditional=True,
accelerator=accelerator)
ema = EMA(model.parameters(), decay=0.99)
ema.to(accelerator.device)
for ns in trainer:
ns.pbar.set_description(f'Loss={ns.loss.item():.5}')
ema.update()

# Sample
with ema.average_parameters():
*xt, x0 = samples(model, schedule.sample_sigmas(20), gam=1.6, batchsize=40,
cond=list(range(10))*4,
accelerator=accelerator)
save_image(((make_grid(x0) + 1)/2).clamp(0, 1), 'fashion_mnist_samples.png')
torch.save(model.state_dict(), 'checkpoint.pth')
9 changes: 8 additions & 1 deletion src/smalldiffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,12 @@
)

from .model import (
TimeInputMLP, ModelMixin, get_sigma_embeds, IdealDenoiser, DiT
ModelMixin,
Scaled, PredX0, PredV,
TimeInputMLP, IdealDenoiser,
get_sigma_embeds,
SigmaEmbedderSinCos,
CondEmbedderLabel,
)

from .model_dit import DiT
48 changes: 28 additions & 20 deletions src/smalldiffusion/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.utils.data import DataLoader
from tqdm import tqdm
from types import SimpleNamespace
from typing import Optional
from typing import Optional, Union, Tuple

class Schedule:
'''Diffusion noise schedules parameterized by sigma'''
Expand Down Expand Up @@ -73,12 +73,15 @@ def __init__(self, N: int=1000, beta_start: float=0.0001, beta_end: float=0.02,
# Given a batch of data x0, returns:
# eps : i.i.d. normal with same shape as x0
# sigma: uniformly sampled from schedule, with shape Bx1x..x1 for broadcasting
def generate_train_sample(x0: torch.FloatTensor, schedule: Schedule):
def generate_train_sample(x0: Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]],
schedule: Schedule, conditional: bool=False):
cond = x0[1] if conditional else None
x0 = x0[0] if conditional else x0
sigma = schedule.sample_batch(x0)
while len(sigma.shape) < len(x0.shape):
sigma = sigma.unsqueeze(-1)
eps = torch.randn_like(x0)
return sigma, eps
return x0, sigma, eps, cond

# Model objects
# Always called with (x, sigma):
Expand All @@ -87,20 +90,22 @@ def generate_train_sample(x0: torch.FloatTensor, schedule: Schedule):
# Otherwise, x[i] will be paired with sigma[i] when calling model
# Have a `rand_input` method for generating random xt during sampling

def training_loop(loader : DataLoader,
model : nn.Module,
schedule : Schedule,
accelerator: Optional[Accelerator] = None,
epochs : int = 10000,
lr : float = 1e-3):
def training_loop(loader : DataLoader,
model : nn.Module,
schedule : Schedule,
accelerator : Optional[Accelerator] = None,
epochs : int = 10000,
lr : float = 1e-3,
conditional : bool = False):
accelerator = accelerator or Accelerator()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
for _ in (pbar := tqdm(range(epochs))):
for x0 in loader:
model.train()
optimizer.zero_grad()
sigma, eps = generate_train_sample(x0, schedule)
loss = model.get_loss(x0, sigma, eps)
x0, sigma, eps, cond = generate_train_sample(x0, schedule, conditional)
loss = model.get_loss(x0, sigma, eps, cond=cond)
yield SimpleNamespace(**locals()) # For extracting training statistics
accelerator.backward(loss)
optimizer.step()
Expand All @@ -114,19 +119,22 @@ def samples(model : nn.Module,
sigmas : torch.FloatTensor, # Iterable with N+1 values for N sampling steps
gam : float = 1., # Suggested to use gam >= 1
mu : float = 0., # Requires mu in [0, 1)
cfg_scale : int = 0., # 0 means no classifier-free guidance
batchsize : int = 1,
xt : Optional[torch.FloatTensor] = None,
accelerator: Optional[Accelerator] = None,
batchsize : int = 1):
cond : Optional[torch.Tensor] = None,
accelerator: Optional[Accelerator] = None):
accelerator = accelerator or Accelerator()
if xt is None:
xt = model.rand_input(batchsize).to(accelerator.device) * sigmas[0]
else:
batchsize = xt.shape[0]
xt = model.rand_input(batchsize).to(accelerator.device) * sigmas[0] if xt is None else xt
if cond is not None:
assert cond.shape[0] == xt.shape[0], 'cond must have same shape as x!'
cond = cond.to(xt.device)
eps = None
for i, (sig, sig_prev) in enumerate(pairwise(sigmas)):
eps, eps_prev = model.predict_eps(xt, sig.to(xt)), eps
model.eval()
eps_prev, eps = eps, model.predict_eps_cfg(xt, sig.to(xt), cond, cfg_scale)
eps_av = eps * gam + eps_prev * (1-gam) if i > 0 else eps
sig_p = (sig_prev/sig**mu)**(1/(1-mu)) # sig_prev == sig**mu sig_p**(1-mu)
eta = (sig_prev**2 - sig_p**2).sqrt()
xt = xt - (sig - sig_p) * eps_av + eta * model.rand_input(batchsize).to(xt)
xt = xt - (sig - sig_p) * eps_av + eta * model.rand_input(xt.shape[0]).to(xt)
yield xt
Loading
Loading