diff --git a/examples/fashion_mnist_dit_cond.py b/examples/fashion_mnist_dit_cond.py
new file mode 100644
index 0000000..8eb025d
--- /dev/null
+++ b/examples/fashion_mnist_dit_cond.py
@@ -0,0 +1,42 @@
+import torch
+from accelerate import Accelerator
+from torch.utils.data import DataLoader
+from torchvision import transforms as tf
+from torchvision.datasets import FashionMNIST
+from torchvision.utils import make_grid, save_image
+from torch_ema import ExponentialMovingAverage as EMA
+from tqdm import tqdm
+
+from smalldiffusion import ScheduleDDPM, samples, training_loop, MappedDataset, DiT, CondEmbedderLabel
+
+# Setup
+accelerator = Accelerator()
+dataset = FashionMNIST('datasets', train=True, download=True,
+                       transform=tf.Compose([
+                           tf.RandomHorizontalFlip(),
+                           tf.ToTensor(),
+                           tf.Lambda(lambda t: (t * 2) - 1)
+                       ]))
+loader = DataLoader(dataset, batch_size=1024, shuffle=True)
+schedule = ScheduleDDPM(beta_start=0.0001, beta_end=0.02, N=1000)
+cond_embed = CondEmbedderLabel(32*6, 10, 0.1)
+model = DiT(in_dim=28, channels=1,
+            patch_size=2, depth=6, head_dim=32, num_heads=6, mlp_ratio=4.0,
+            cond_embed=cond_embed)
+
+# Train
+trainer = training_loop(loader, model, schedule, epochs=300, lr=1e-3, conditional=True,
+                        accelerator=accelerator)
+ema = EMA(model.parameters(), decay=0.99)
+ema.to(accelerator.device)
+for ns in trainer:
+    ns.pbar.set_description(f'Loss={ns.loss.item():.5}')
+    ema.update()
+
+# Sample
+with ema.average_parameters():
+    *xt, x0 = samples(model, schedule.sample_sigmas(20), gam=1.6, batchsize=40,
+                      cond=torch.tensor(list(range(10))*4),
+                      accelerator=accelerator)
+    save_image(((make_grid(x0) + 1)/2).clamp(0, 1), 'fashion_mnist_samples.png')
+    torch.save(model.state_dict(), 'checkpoint.pth')
diff --git a/src/smalldiffusion/__init__.py b/src/smalldiffusion/__init__.py
index f618525..3e3e04b 100644
--- a/src/smalldiffusion/__init__.py
+++ b/src/smalldiffusion/__init__.py
@@ -8,5 +8,12 @@
 )
 
 from .model import (
-    TimeInputMLP, ModelMixin, get_sigma_embeds, IdealDenoiser, DiT
+    ModelMixin,
+    Scaled, PredX0, PredV,
+    TimeInputMLP, IdealDenoiser,
+    get_sigma_embeds,
+    SigmaEmbedderSinCos,
+    CondEmbedderLabel,
 )
+
+from .model_dit import DiT
diff --git a/src/smalldiffusion/diffusion.py b/src/smalldiffusion/diffusion.py
index 044c310..9e1e3b1 100644
--- a/src/smalldiffusion/diffusion.py
+++ b/src/smalldiffusion/diffusion.py
@@ -8,7 +8,7 @@
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from types import SimpleNamespace
-from typing import Optional
+from typing import Optional, Union, Tuple
 
 class Schedule:
     '''Diffusion noise schedules parameterized by sigma'''
@@ -73,12 +73,15 @@ def __init__(self, N: int=1000, beta_start: float=0.0001, beta_end: float=0.02,
 # Given a batch of data x0, returns:
 #   eps  : i.i.d. normal with same shape as x0
 #   sigma: uniformly sampled from schedule, with shape Bx1x..x1 for broadcasting
-def generate_train_sample(x0: torch.FloatTensor, schedule: Schedule):
+def generate_train_sample(x0: Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]],
+                          schedule: Schedule, conditional: bool=False):
+    cond = x0[1] if conditional else None
+    x0 = x0[0] if conditional else x0
     sigma = schedule.sample_batch(x0)
     while len(sigma.shape) < len(x0.shape):
         sigma = sigma.unsqueeze(-1)
     eps = torch.randn_like(x0)
-    return sigma, eps
+    return x0, sigma, eps, cond
 
 # Model objects
 # Always called with (x, sigma):
@@ -87,20 +90,22 @@
 #   If sigma.shape == [], model will be called with the same sigma for each x0
 #   Otherwise, x[i] will be paired with sigma[i] when calling model
 # Have a `rand_input` method for generating random xt during sampling
-def training_loop(loader      : DataLoader,
-                  model       : nn.Module,
-                  schedule    : Schedule,
-                  accelerator: Optional[Accelerator] = None,
-                  epochs      : int = 10000,
-                  lr          : float = 1e-3):
+def training_loop(loader      : DataLoader,
+                  model       : nn.Module,
+                  schedule    : Schedule,
+                  accelerator : Optional[Accelerator] = None,
+                  epochs      : int = 10000,
+                  lr          : float = 1e-3,
+                  conditional : bool = False):
     accelerator = accelerator or Accelerator()
-    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
     model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
     for _ in (pbar := tqdm(range(epochs))):
         for x0 in loader:
+            model.train()
             optimizer.zero_grad()
-            sigma, eps = generate_train_sample(x0, schedule)
-            loss = model.get_loss(x0, sigma, eps)
+            x0, sigma, eps, cond = generate_train_sample(x0, schedule, conditional)
+            loss = model.get_loss(x0, sigma, eps, cond=cond)
             yield SimpleNamespace(**locals()) # For extracting training statistics
             accelerator.backward(loss)
             optimizer.step()
@@ -114,19 +119,22 @@ def samples(model      : nn.Module,
             sigmas     : torch.FloatTensor, # Iterable with N+1 values for N sampling steps
             gam        : float = 1.,        # Suggested to use gam >= 1
             mu         : float = 0.,        # Requires mu in [0, 1)
+            cfg_scale  : float = 0.,        # 0 means no classifier-free guidance
+            batchsize  : int = 1,
             xt         : Optional[torch.FloatTensor] = None,
-            accelerator: Optional[Accelerator] = None,
-            batchsize  : int = 1):
+            cond       : Optional[torch.Tensor] = None,
+            accelerator: Optional[Accelerator] = None):
     accelerator = accelerator or Accelerator()
-    if xt is None:
-        xt = model.rand_input(batchsize).to(accelerator.device) * sigmas[0]
-    else:
-        batchsize = xt.shape[0]
+    xt = model.rand_input(batchsize).to(accelerator.device) * sigmas[0] if xt is None else xt
+    if cond is not None:
+        assert cond.shape[0] == xt.shape[0], 'cond must have same batch size as xt!'
+        cond = cond.to(xt.device)
     eps = None
     for i, (sig, sig_prev) in enumerate(pairwise(sigmas)):
-        eps, eps_prev = model.predict_eps(xt, sig.to(xt)), eps
+        model.eval()
+        eps_prev, eps = eps, model.predict_eps_cfg(xt, sig.to(xt), cond, cfg_scale)
         eps_av = eps * gam + eps_prev * (1-gam) if i > 0 else eps
         sig_p = (sig_prev/sig**mu)**(1/(1-mu)) # sig_prev == sig**mu sig_p**(1-mu)
         eta = (sig_prev**2 - sig_p**2).sqrt()
-        xt = xt - (sig - sig_p) * eps_av + eta * model.rand_input(batchsize).to(xt)
+        xt = xt - (sig - sig_p) * eps_av + eta * model.rand_input(xt.shape[0]).to(xt)
         yield xt
diff --git a/src/smalldiffusion/model.py b/src/smalldiffusion/model.py
index 5aae40b..df8089b 100644
--- a/src/smalldiffusion/model.py
+++ b/src/smalldiffusion/model.py
@@ -1,6 +1,5 @@
 import math
 import torch
-import numpy as np
 import torch.nn.functional as F
 from torch import nn
 from einops import rearrange, repeat
@@ -26,11 +25,21 @@ def rand_input(self, batchsize):
         return torch.randn((batchsize,) + self.input_dims)
 
     # Currently predicts eps, override following methods to predict, for example, x0
-    def get_loss(self, x0, sigma, eps):
-        return nn.MSELoss()(eps, self(x0 + sigma * eps, sigma))
-
-    def predict_eps(self, x, sigma):
-        return self(x, sigma)
+    def get_loss(self, x0, sigma, eps, cond=None, loss=nn.MSELoss):
+        return loss()(eps, self(x0 + sigma * eps, sigma, cond=cond))
+
+    def predict_eps(self, x, sigma, cond=None):
+        return self(x, sigma, cond=cond)
+
+    def predict_eps_cfg(self, x, sigma, cond, cfg_scale):
+        if cond is None:
+            return self.predict_eps(x, sigma)
+        assert sigma.shape == tuple(), 'CFG sampling only supports singleton sigma!'
+        uncond = torch.full_like(cond, self.cond_embed.null_cond) # (B,)
+        eps_cond, eps_uncond = self.predict_eps(                  # (B,), (B,)
+            torch.cat([x, x]), sigma, torch.cat([cond, uncond])   # (2B,)
+        ).chunk(2)
+        return eps_cond + cfg_scale * (eps_cond - eps_uncond)
 
 ## Modifiers for models, such as including scaling or changing model predictions
 
@@ -39,28 +48,28 @@ def alpha(sigma):
 
 # Scale model input so that its norm stays constant for all sigma
 def Scaled(cls: ModelMixin):
-    def forward(self, x, sigma):
-        return cls.forward(self, x * alpha(sigma).sqrt(), sigma)
+    def forward(self, x, sigma, cond=None):
+        return cls.forward(self, x * alpha(sigma).sqrt(), sigma, cond=cond)
     return type(cls.__name__ + 'Scaled', (cls,), dict(forward=forward))
 
 # Train model to predict x0 instead of eps
 def PredX0(cls: ModelMixin):
-    def get_loss(self, x0, sigma, eps):
-        return nn.MSELoss()(x0, self(x0 + sigma * eps, sigma))
-    def predict_eps(self, x, sigma):
-        x0_hat = self(x, sigma)
+    def get_loss(self, x0, sigma, eps, cond=None, loss=nn.MSELoss):
+        return loss()(x0, self(x0 + sigma * eps, sigma, cond=cond))
+    def predict_eps(self, x, sigma, cond=None):
+        x0_hat = self(x, sigma, cond=cond)
         return (x - x0_hat)/sigma
     return type(cls.__name__ + 'PredX0', (cls,), dict(get_loss=get_loss, predict_eps=predict_eps))
 
 # Train model to predict v (https://arxiv.org/pdf/2202.00512.pdf) instead of eps
 def PredV(cls: ModelMixin):
-    def get_loss(self, x0, sigma, eps):
+    def get_loss(self, x0, sigma, eps, cond=None, loss=nn.MSELoss):
         xt = x0 + sigma * eps
         v = alpha(sigma).sqrt() * eps - (1-alpha(sigma)).sqrt() * x0
-        return nn.MSELoss()(v, self(xt, sigma))
-    def predict_eps(self, x, sigma):
-        v_hat = self(x, sigma)
+        return loss()(v, self(xt, sigma, cond=cond))
+    def predict_eps(self, x, sigma, cond=None):
+        v_hat = self(x, sigma, cond=cond)
         return alpha(sigma).sqrt() * (v_hat + (1-alpha(sigma)).sqrt() * x)
     return type(cls.__name__ + 'PredV', (cls,), dict(get_loss=get_loss, predict_eps=predict_eps))
 
@@ -79,7 +88,7 @@ def __init__(self, dim=2, hidden_dims=(16,128,256,128,16)):
         self.net = nn.Sequential(*layers)
         self.input_dims = (dim,)
 
-    def forward(self, x, sigma):
+    def forward(self, x, sigma, cond=None):
         # x shape: b x dim
         # sigma shape: b x 1 or scalar
         sigma_embeds = get_sigma_embeds(x.shape[0], sigma.squeeze()) # shape: b x 2
@@ -93,43 +102,37 @@
 def sq_norm(M, k):
     # M: b x n --(norm)--> b --(repeat)--> b x k
     return (torch.norm(M, dim=1)**2).unsqueeze(1).repeat(1,k)
 
-class IdealDenoiser(ModelMixin):
+class IdealDenoiser(nn.Module, ModelMixin):
     def __init__(self, dataset: torch.utils.data.Dataset):
+        super().__init__()
         self.data = torch.stack([dataset[i] for i in range(len(dataset))])
         self.input_dims = self.data.shape[1:]
 
-    def __call__(self, x, sigma):
-        assert sigma.shape == tuple(), 'Only singleton sigma supported'
-        data = self.data.to(x)
+    def forward(self, x, sigma, cond=None):
+        data = self.data.to(x) # shape: db x c1 x ... x cn
         x_flat = x.flatten(start_dim=1)
         d_flat = data.flatten(start_dim=1)
         xb, xr = x_flat.shape
         db, dr = d_flat.shape
         assert xr == dr, 'Input x must have same dimension as data!'
-        # ||x - x0||^2 ,shape xb x db
-        sq_diffs = sq_norm(x_flat, db) + sq_norm(d_flat, xb).T - 2 * x_flat @ d_flat.T
-        weights = torch.nn.functional.softmax(-sq_diffs/2/sigma**2, dim=1)
-        return (x - torch.einsum('ij,j...->i...', weights, data))/sigma
-
-
-## Diffusion transformer
-
-class PatchEmbed(nn.Module):
-    def __init__(self, patch_size=16, channels=3, embed_dim=768, bias=True):
-        super().__init__()
-        self.proj = nn.Conv2d(channels, embed_dim, stride=patch_size, kernel_size=patch_size, bias=bias)
-        self.init()
-
-    def init(self): # Init like nn.Linear
-        w = self.proj.weight.data
-        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
-        nn.init.constant_(self.proj.bias, 0)
-
-    def forward(self, x):
-        return rearrange(self.proj(x), 'b c h w -> b (h w) c')
+        assert sigma.shape == tuple() or sigma.shape[0] == xb, \
+            f'sigma must be singleton or have same batch dimension as x! {sigma.shape}'
+        # sq_diffs: ||x - x0||^2
+        sq_diffs = sq_norm(x_flat, db).T + sq_norm(d_flat, xb) - 2 * d_flat @ x_flat.T # shape: db x xb
+        weights = torch.nn.functional.softmax(-sq_diffs/2/sigma.squeeze()**2, dim=0)   # shape: db x xb
+        eps = torch.einsum('ij,i...->j...', weights, data)                             # shape: xb x c1 x ... x cn
+        return (x - eps) / sigma
+
+## Common functions for other models
+
+class CondSequential(nn.Sequential):
+    def forward(self, x, cond):
+        for module in self._modules.values():
+            x = module(x, cond)
+        return x
 
 class Attention(nn.Module):
-    def __init__(self, head_dim, num_heads=8, qkv_bias=False, norm_layer=nn.LayerNorm):
+    def __init__(self, head_dim, num_heads=8, qkv_bias=False):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = head_dim
@@ -146,121 +149,20 @@ def forward(self, x):
                       'b h n k -> b n (h k)')
         return self.proj(x)
 
-class Modulation(nn.Module):
-    def __init__(self, dim, n):
-        super().__init__()
-        self.n = n
-        self.proj = nn.Sequential(nn.SiLU(), nn.Linear(dim, n * dim, bias=True))
-        nn.init.constant_(self.proj[-1].weight, 0)
-        nn.init.constant_(self.proj[-1].bias, 0)
-
-    def forward(self, y):
-        return [m.unsqueeze(1) for m in self.proj(y).chunk(self.n, dim=1)]
-
-class ModulatedLayerNorm(nn.LayerNorm):
-    def __init__(self, dim, **kwargs):
-        super().__init__(dim, **kwargs)
-        self.modulation = Modulation(dim, 2)
-    def forward(self, x, y):
-        scale, shift = self.modulation(y)
-        return super().forward(x) * (1 + scale) + shift
-
-class DiTBlock(nn.Module):
-    def __init__(self, head_dim, num_heads, mlp_ratio=4.0):
-        super().__init__()
-        dim = head_dim * num_heads
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.norm1 = ModulatedLayerNorm(dim, elementwise_affine=False, eps=1e-6)
-        self.attn = Attention(head_dim, num_heads=num_heads, qkv_bias=True)
-        self.norm2 = ModulatedLayerNorm(dim, elementwise_affine=False, eps=1e-6)
-        self.mlp = nn.Sequential(
-            nn.Linear(dim, mlp_hidden_dim, bias=True),
-            nn.GELU(approximate="tanh"),
-            nn.Linear(mlp_hidden_dim, dim, bias=True),
-        )
-        self.scale_modulation = Modulation(dim, 2)
-
-    def forward(self, x, y):
-        # (B, N, D), (B, D) -> (B, N, D)
-        # N = H * W / patch_size**2, D = num_heads * head_dim
-        gate_msa, gate_mlp = self.scale_modulation(y)
-        x = x + gate_msa * self.attn(self.norm1(x, y))
-        x = x + gate_mlp * self.mlp(self.norm2(x, y))
-        return x
-
-def get_pos_embed(in_dim, patch_size, dim, N=10000):
-    n = in_dim // patch_size # Number of patches per side
-    assert dim % 4 == 0, 'Embedding dimension must be multiple of 4!'
-    omega = 1/N**np.linspace(0, 1, dim // 4, endpoint=False)      # [dim/4]
-    freqs = np.outer(np.arange(n), omega)                         # [n, dim/4]
-    embeds = repeat(np.stack([np.sin(freqs), np.cos(freqs)]),
-                    ' b n d -> b n k d', k=n)                     # [2, n, n, dim/4]
-    embeds_2d = np.concatenate([
-        rearrange(embeds, 'b n k d -> (k n) (b d)'),              # [n*n, dim/2]
-        rearrange(embeds, 'b n k d -> (n k) (b d)'),              # [n*n, dim/2]
-    ], axis=1)                                                    # [n*n, dim]
-    return nn.Parameter(torch.tensor(embeds_2d).float().unsqueeze(0), # [1, n*n, dim]
-                        requires_grad=False)
-
-class DiT(nn.Module, ModelMixin):
-    def __init__(self, in_dim=32, channels=3, patch_size=2, depth=12,
-                 head_dim=64, num_heads=6, mlp_ratio=4.0, sig_embed_factor=0.5,
-                 sig_embed_class=None):
+# Embedding table for conditioning on labels assumed to be in [0, num_classes),
+# unconditional label encoded as: num_classes
+class CondEmbedderLabel(nn.Module):
+    def __init__(self, hidden_size, num_classes, dropout_prob=0.1):
         super().__init__()
-        self.in_dim = in_dim
-        self.channels = channels
-        self.patch_size = patch_size
-        self.input_dims = (channels, in_dim, in_dim)
-
-        dim = head_dim * num_heads
-
-        self.pos_embed = get_pos_embed(in_dim, patch_size, dim)
-        self.x_embed = PatchEmbed(patch_size, channels, dim, bias=True)
-        self.sig_embed = (sig_embed_class or SigmaEmbedderSinCos)(
-            dim, scaling_factor=sig_embed_factor
-        )
-
-        self.blocks = nn.ModuleList([
-            DiTBlock(head_dim, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
-        ])
-
-        self.final_norm = ModulatedLayerNorm(dim, elementwise_affine=False, eps=1e-6)
-        self.final_linear = nn.Linear(dim, patch_size**2 * channels)
-        self.init()
-
-    def init(self):
-        # Initialize transformer layers:
-        def _basic_init(module):
-            if isinstance(module, nn.Linear):
-                torch.nn.init.xavier_uniform_(module.weight)
-                if module.bias is not None:
-                    nn.init.constant_(module.bias, 0)
-        self.apply(_basic_init)
-
-        # Initialize sigma embedding MLP:
-        nn.init.normal_(self.sig_embed.mlp[0].weight, std=0.02)
-        nn.init.normal_(self.sig_embed.mlp[2].weight, std=0.02)
-
-        # Zero-out output layers:
-        nn.init.constant_(self.final_linear.weight, 0)
-        nn.init.constant_(self.final_linear.bias, 0)
-
-    def unpatchify(self, x):
-        # (B, N, patchsize**2 * channels) -> (B, channels, H, W)
-        patches = self.in_dim // self.patch_size
-        return rearrange(x, 'b (ph pw) (psh psw c) -> b c (ph psh) (pw psw)',
-                         ph=patches, pw=patches,
-                         psh=self.patch_size, psw=self.patch_size)
-
-    def forward(self, x, sigma):
-        # (B, C, H, W), Union[(B, 1, 1, 1), ()] -> (B, C, H, W)
-        # N = num_patches, D = dim = head_dim * num_heads
-        x = self.x_embed(x) + self.pos_embed                 # (B, N, D)
-        y = self.sig_embed(x.shape[0], sigma.squeeze())      # (B, D)
-        for block in self.blocks:
-            x = block(x, y)                                  # (B, N, D)
-        x = self.final_linear(self.final_norm(x, y))         # (B, N, patchsize**2 * channels)
-        return self.unpatchify(x)
+        self.embeddings = nn.Embedding(num_classes + 1, hidden_size)
+        self.null_cond = num_classes
+        self.dropout_prob = dropout_prob
+
+    def forward(self, labels): # (B,) -> (B, D)
+        if self.training:
+            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+            labels = torch.where(drop_ids, self.null_cond, labels)
+        return self.embeddings(labels)
 
 # A simple embedding that works just as well as usual sinusoidal embedding
 class SigmaEmbedderSinCos(nn.Module):
diff --git a/src/smalldiffusion/model_dit.py b/src/smalldiffusion/model_dit.py
new file mode 100644
index 0000000..3f7bae1
--- /dev/null
+++ b/src/smalldiffusion/model_dit.py
@@ -0,0 +1,143 @@
+import torch
+import numpy as np
+from torch import nn
+from einops import rearrange, repeat
+from .model import ModelMixin, Attention, SigmaEmbedderSinCos, CondSequential
+
+## Diffusion transformer
+
+class PatchEmbed(nn.Module):
+    def __init__(self, patch_size=16, channels=3, embed_dim=768, bias=True):
+        super().__init__()
+        self.proj = nn.Conv2d(channels, embed_dim, stride=patch_size, kernel_size=patch_size, bias=bias)
+        self.init()
+
+    def init(self): # Init like nn.Linear
+        w = self.proj.weight.data
+        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+        nn.init.constant_(self.proj.bias, 0)
+
+    def forward(self, x):
+        return rearrange(self.proj(x), 'b c h w -> b (h w) c')
+
+class Modulation(nn.Module):
+    def __init__(self, dim, n):
+        super().__init__()
+        self.n = n
+        self.proj = nn.Sequential(nn.SiLU(), nn.Linear(dim, n * dim, bias=True))
+        nn.init.constant_(self.proj[-1].weight, 0)
+        nn.init.constant_(self.proj[-1].bias, 0)
+
+    def forward(self, y):
+        return [m.unsqueeze(1) for m in self.proj(y).chunk(self.n, dim=1)]
+
+class ModulatedLayerNorm(nn.LayerNorm):
+    def __init__(self, dim, **kwargs):
+        super().__init__(dim, **kwargs)
+        self.modulation = Modulation(dim, 2)
+    def forward(self, x, y):
+        scale, shift = self.modulation(y)
+        return super().forward(x) * (1 + scale) + shift
+
+class DiTBlock(nn.Module):
+    def __init__(self, head_dim, num_heads, mlp_ratio=4.0):
+        super().__init__()
+        dim = head_dim * num_heads
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.norm1 = ModulatedLayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.attn = Attention(head_dim, num_heads=num_heads, qkv_bias=True)
+        self.norm2 = ModulatedLayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.mlp = nn.Sequential(
+            nn.Linear(dim, mlp_hidden_dim, bias=True),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(mlp_hidden_dim, dim, bias=True),
+        )
+        self.scale_modulation = Modulation(dim, 2)
+
+    def forward(self, x, y):
+        # (B, N, D), (B, D) -> (B, N, D)
+        # N = H * W / patch_size**2, D = num_heads * head_dim
+        gate_msa, gate_mlp = self.scale_modulation(y)
+        x = x + gate_msa * self.attn(self.norm1(x, y))
+        x = x + gate_mlp * self.mlp(self.norm2(x, y))
+        return x
+
+def get_pos_embed(in_dim, patch_size, dim, N=10000):
+    n = in_dim // patch_size # Number of patches per side
+    assert dim % 4 == 0, 'Embedding dimension must be multiple of 4!'
+    omega = 1/N**np.linspace(0, 1, dim // 4, endpoint=False)      # [dim/4]
+    freqs = np.outer(np.arange(n), omega)                         # [n, dim/4]
+    embeds = repeat(np.stack([np.sin(freqs), np.cos(freqs)]),
+                    ' b n d -> b n k d', k=n)                     # [2, n, n, dim/4]
+    embeds_2d = np.concatenate([
+        rearrange(embeds, 'b n k d -> (k n) (b d)'),              # [n*n, dim/2]
+        rearrange(embeds, 'b n k d -> (n k) (b d)'),              # [n*n, dim/2]
+    ], axis=1)                                                    # [n*n, dim]
+    return nn.Parameter(torch.tensor(embeds_2d).float().unsqueeze(0), # [1, n*n, dim]
+                        requires_grad=False)
+
+class DiT(nn.Module, ModelMixin):
+    def __init__(self, in_dim=32, channels=3, patch_size=2, depth=12,
+                 head_dim=64, num_heads=6, mlp_ratio=4.0,
+                 sig_embed_class=None, sig_embed_factor=0.5,
+                 cond_embed=None):
+        super().__init__()
+        self.in_dim = in_dim
+        self.channels = channels
+        self.patch_size = patch_size
+        self.input_dims = (channels, in_dim, in_dim)
+
+        dim = head_dim * num_heads
+
+        self.pos_embed = get_pos_embed(in_dim, patch_size, dim)
+        self.x_embed = PatchEmbed(patch_size, channels, dim, bias=True)
+        self.sig_embed = (sig_embed_class or SigmaEmbedderSinCos)(
+            dim, scaling_factor=sig_embed_factor
+        )
+        self.cond_embed = cond_embed
+
+        self.blocks = CondSequential(*[
+            DiTBlock(head_dim, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
+        ])
+
+        self.final_norm = ModulatedLayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.final_linear = nn.Linear(dim, patch_size**2 * channels)
+        self.init()
+
+    def init(self):
+        # Initialize transformer layers:
+        def _basic_init(module):
+            if isinstance(module, nn.Linear):
+                torch.nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+        self.apply(_basic_init)
+
+        # Initialize sigma embedding MLP:
+        nn.init.normal_(self.sig_embed.mlp[0].weight, std=0.02)
+        nn.init.normal_(self.sig_embed.mlp[2].weight, std=0.02)
+
+        # Zero-out output layers:
+        nn.init.constant_(self.final_linear.weight, 0)
+        nn.init.constant_(self.final_linear.bias, 0)
+
+    def unpatchify(self, x):
+        # (B, N, patchsize**2 * channels) -> (B, channels, H, W)
+        patches = self.in_dim // self.patch_size
+        return rearrange(x, 'b (ph pw) (psh psw c) -> b c (ph psh) (pw psw)',
+                         ph=patches, pw=patches,
+                         psh=self.patch_size, psw=self.patch_size)
+
+    def forward(self, x, sigma, cond=None):
+        # x: (B, C, H, W), sigma: Union[(B, 1, 1, 1), ()], cond: (B, *)
+        # returns: (B, C, H, W)
+        # N = num_patches, D = dim = head_dim * num_heads
+        x = self.x_embed(x) + self.pos_embed                 # (B, N, D)
+        y = self.sig_embed(x.shape[0], sigma.squeeze())      # (B, D)
+        if self.cond_embed is not None:
+            assert cond is not None and x.shape[0] == cond.shape[0], \
+                'Conditioning must have same batch size as x!'
+            y += self.cond_embed(cond)                       # (B, D)
+        x = self.blocks(x, y)                                # (B, N, D)
+        x = self.final_linear(self.final_norm(x, y))         # (B, N, patchsize**2 * channels)
+        return self.unpatchify(x)
diff --git a/tests/test_diffusion.py b/tests/test_diffusion.py
index a79f1b3..abf24b3 100644
--- a/tests/test_diffusion.py
+++ b/tests/test_diffusion.py
@@ -10,11 +10,12 @@ def get_hf_sigmas(scheduler):
     return (1/scheduler.alphas_cumprod - 1).sqrt()
 
-class DummyModel(ModelMixin):
+class DummyModel(torch.nn.Module, ModelMixin):
     def __init__(self, dims):
+        super().__init__()
         self.input_dims = dims
 
-    def __call__(self, x, sigma):
+    def __call__(self, x, sigma, cond=None):
         gen = torch.Generator().manual_seed(int(sigma * 100000))
         return torch.randn((x.shape[0],) + self.input_dims, generator=gen)
 
@@ -184,11 +185,42 @@ def test_swissroll(self):
                               accelerator=accelerator)
         self.assertEqual(sample.shape, (B//2, 2))
 
+class TestIdeal(unittest.TestCase, TensorTest):
+    # Test that ideal denoiser batching works
+    def test_ideal(self):
+        for N in [1, 10, 99]:
+            loader = DataLoader(Swissroll(np.pi/2, 5*np.pi, 30), batch_size=2000)
+            sigmas = torch.linspace(1, 2, N)
+            idd = IdealDenoiser(loader.dataset)
+            x0 = idd.rand_input(N)
+            batched_output = idd(x0, sigmas.unsqueeze(1))
+            singleton_output = torch.cat([idd(x0i.unsqueeze(0), s) for x0i, s in zip(x0, sigmas)])
+            self.assertAlmostEqualTensors(batched_output, singleton_output, tol=1e-6)
+
+# Just testing that model creation and forward pass works
 class TestDiT(unittest.TestCase):
-    def test_basic_setup(self):
-        # Just testing that model creation and forward pass works
-        model = DiT(in_dim=16, channels=3, patch_size=2, depth=4, head_dim=32, num_heads=6)
-        x = torch.randn(10, 3, 16, 16)
-        sigma = torch.tensor(1)
-        y = model(x, sigma)
-        self.assertEqual(y.shape, x.shape)
+    def setUp(self):
+        self.modifiers = [
+            Scaled, PredX0, PredV,
+            lambda x: x,
+            lambda x: Scaled(PredX0(x)),
+            lambda x: Scaled(PredV(x))
+        ]
+
+    def test_uncond(self):
+        for modifier in self.modifiers:
+            model = modifier(DiT)(in_dim=16, channels=3, patch_size=2, depth=4, head_dim=32, num_heads=6)
+            x = torch.randn(10, 3, 16, 16)
+            sigma = torch.tensor(1)
+            y = model.predict_eps(x, sigma)
+            self.assertEqual(y.shape, x.shape)
+
+    def test_cond(self):
+        for modifier in self.modifiers:
+            model = modifier(DiT)(in_dim=16, channels=3, patch_size=2, depth=4, head_dim=32, num_heads=6,
+                                  cond_embed=CondEmbedderLabel(32*6, 10))
+            x = torch.randn(10, 3, 16, 16)
+            sigma = torch.tensor(1)
+            labels = torch.tensor([1,2,3,4,5] + [10]*5)
+            y = model.predict_eps_cfg(x, sigma, cond=labels, cfg_scale=4.0)
+            self.assertEqual(y.shape, x.shape)
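
Usage note (not part of the patch): the sketch below shows how the new cond and cfg_scale arguments of samples() fit together at inference time. It assumes the checkpoint.pth written by examples/fashion_mnist_dit_cond.py above; the guidance strength cfg_scale=4.0 and the output filename are illustrative choices, not values taken from the diff.

# Sketch: class-conditional sampling with classifier-free guidance, assuming
# 'checkpoint.pth' was produced by the training example above.
import torch
from accelerate import Accelerator
from torchvision.utils import make_grid, save_image
from smalldiffusion import ScheduleDDPM, samples, DiT, CondEmbedderLabel

accelerator = Accelerator()
schedule = ScheduleDDPM(beta_start=0.0001, beta_end=0.02, N=1000)
model = DiT(in_dim=28, channels=1, patch_size=2, depth=6, head_dim=32, num_heads=6,
            mlp_ratio=4.0, cond_embed=CondEmbedderLabel(32*6, 10, 0.1))
model.load_state_dict(torch.load('checkpoint.pth', map_location='cpu'))
model.to(accelerator.device)

labels = torch.tensor(list(range(10)) * 4)   # 4 samples per FashionMNIST class
*_, x0 = samples(model, schedule.sample_sigmas(20), gam=1.6,
                 batchsize=len(labels), cond=labels,
                 cfg_scale=4.0,              # > 0 enables classifier-free guidance
                 accelerator=accelerator)
save_image(((make_grid(x0) + 1) / 2).clamp(0, 1), 'fashion_mnist_cfg_samples.png')

With cfg_scale=0 (the default, and what the training example uses) the model is still label-conditioned through cond_embed, but no unconditional pass or extrapolation is performed.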
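
A second hedged sketch: the conditional training path expects each DataLoader batch to be an (x0, label) pair, which torchvision datasets such as FashionMNIST already yield; generate_train_sample splits that pair when conditional=True, and the prediction-type modifiers compose with the conditional DiT exactly as exercised in the tests. generate_train_sample is imported from smalldiffusion.diffusion here because the package __init__ does not re-export it; all sizes below are illustrative.

# Sketch of how the conditional pieces compose; shapes and hyperparameters are illustrative.
import torch
from smalldiffusion import ScheduleDDPM, DiT, CondEmbedderLabel, Scaled, PredV
from smalldiffusion.diffusion import generate_train_sample

schedule = ScheduleDDPM(N=1000, beta_start=0.0001, beta_end=0.02)
model = Scaled(PredV(DiT))(in_dim=16, channels=3, patch_size=2, depth=2,
                           head_dim=32, num_heads=6,
                           cond_embed=CondEmbedderLabel(32*6, 10))

# A conditional batch is an (x0, label) pair, as produced by e.g. FashionMNIST.
batch = (torch.randn(8, 3, 16, 16), torch.randint(0, 10, (8,)))
x0, sigma, eps, cond = generate_train_sample(batch, schedule, conditional=True)
loss = model.get_loss(x0, sigma, eps, cond=cond)   # v-prediction MSE with label conditioning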