Updated conditioning, refactored DiT
yuanchenyang committed Oct 11, 2024
1 parent 9325dc1 commit 3eb69c8
Showing 4 changed files with 166 additions and 155 deletions.
12 changes: 8 additions & 4 deletions src/smalldiffusion/__init__.py
@@ -8,8 +8,12 @@
)

from .model import (
ModelMixin, Scaled, PredX0, PredV,
TimeInputMLP, IdealDenoiser, DiT,
get_sigma_embeds, SigmaEmbedderSinCos,
CondEmbedderLabel
ModelMixin,
Scaled, PredX0, PredV,
TimeInputMLP, IdealDenoiser,
get_sigma_embeds,
SigmaEmbedderSinCos,
CondEmbedderLabel,
)

from .model_dit import DiT
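
After this change, DiT lives in its own module but is still re-exported at the package root, so existing imports keep working. A minimal sketch of the two equivalent import paths (assuming the package is importable as smalldiffusion):

    from smalldiffusion import DiT                           # re-exported at the package root
    from smalldiffusion.model_dit import DiT as DiT_direct   # new home of the class

    assert DiT is DiT_direct                                  # same class object either way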
164 changes: 14 additions & 150 deletions src/smalldiffusion/model.py
@@ -1,6 +1,5 @@
import math
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from einops import rearrange, repeat
@@ -26,8 +25,8 @@ def rand_input(self, batchsize):
return torch.randn((batchsize,) + self.input_dims)

# Currently predicts eps, override following methods to predict, for example, x0
def get_loss(self, x0, sigma, eps, cond=None):
return nn.MSELoss()(eps, self(x0 + sigma * eps, sigma, cond=cond))
def get_loss(self, x0, sigma, eps, cond=None, loss=nn.MSELoss):
return loss()(eps, self(x0 + sigma * eps, sigma, cond=cond))

def predict_eps(self, x, sigma, cond=None):
return self(x, sigma, cond=cond)
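
The refactored get_loss now takes the loss class as a keyword argument, defaulting to nn.MSELoss. A minimal sketch of swapping in a different criterion, assuming TimeInputMLP's default constructor and a scalar sigma as used in the tests:

    import torch
    from torch import nn
    from smalldiffusion import TimeInputMLP

    model = TimeInputMLP()                     # any ModelMixin subclass works here
    x0 = model.rand_input(batchsize=4)         # toy (4, 2) input
    eps = torch.randn_like(x0)
    sigma = torch.tensor(1.0)

    loss_mse = model.get_loss(x0, sigma, eps)                  # default: nn.MSELoss
    loss_l1  = model.get_loss(x0, sigma, eps, loss=nn.L1Loss)  # any zero-arg loss class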
@@ -55,8 +54,8 @@ def forward(self, x, sigma, cond=None):

# Train model to predict x0 instead of eps
def PredX0(cls: ModelMixin):
def get_loss(self, x0, sigma, eps, cond=None):
return nn.MSELoss()(x0, self(x0 + sigma * eps, sigma, cond=cond))
def get_loss(self, x0, sigma, eps, cond=None, loss=nn.MSELoss):
return loss()(x0, self(x0 + sigma * eps, sigma, cond=cond))
def predict_eps(self, x, sigma, cond=None):
x0_hat = self(x, sigma, cond=cond)
return (x - x0_hat)/sigma
@@ -65,10 +64,10 @@ def predict_eps(self, x, sigma, cond=None):

# Train model to predict v (https://arxiv.org/pdf/2202.00512.pdf) instead of eps
def PredV(cls: ModelMixin):
def get_loss(self, x0, sigma, eps, cond=None):
def get_loss(self, x0, sigma, eps, cond=None, loss=nn.MSELoss):
xt = x0 + sigma * eps
v = alpha(sigma).sqrt() * eps - (1-alpha(sigma)).sqrt() * x0
return nn.MSELoss()(v, self(xt, sigma, cond=cond))
return loss()(v, self(xt, sigma, cond=cond))
def predict_eps(self, x, sigma, cond=None):
v_hat = self(x, sigma, cond=cond)
return alpha(sigma).sqrt() * (v_hat + (1-alpha(sigma)).sqrt() * x)
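
These prediction-target wrappers are applied to a model class before instantiation, as in the updated test below; a minimal sketch (the DiT hyperparameters are illustrative, mirroring the test configuration):

    from smalldiffusion import PredV, DiT

    # PredV(DiT) yields a DiT variant trained to predict v instead of eps.
    VDiT = PredV(DiT)
    model = VDiT(in_dim=16, channels=3, patch_size=2, depth=4, head_dim=32, num_heads=6)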
@@ -124,25 +123,16 @@ def forward(self, x, sigma, cond=None):
eps = torch.einsum('ij,i...->j...', weights, data) # shape: xb x c1 x ... x cn
return (x - eps) / sigma

## Common functions for other models

## Diffusion transformer

class PatchEmbed(nn.Module):
def __init__(self, patch_size=16, channels=3, embed_dim=768, bias=True):
super().__init__()
self.proj = nn.Conv2d(channels, embed_dim, stride=patch_size, kernel_size=patch_size, bias=bias)
self.init()

def init(self): # Init like nn.Linear
w = self.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.proj.bias, 0)

def forward(self, x):
return rearrange(self.proj(x), 'b c h w -> b (h w) c')
class CondSequential(nn.Sequential):
def forward(self, x, cond):
for module in self._modules.values():
x = module(x, cond)
return x

class Attention(nn.Module):
def __init__(self, head_dim, num_heads=8, qkv_bias=False, norm_layer=nn.LayerNorm):
def __init__(self, head_dim, num_heads=8, qkv_bias=False):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
@@ -159,136 +149,10 @@ def forward(self, x):
'b h n k -> b n (h k)')
return self.proj(x)

class Modulation(nn.Module):
def __init__(self, dim, n):
super().__init__()
self.n = n
self.proj = nn.Sequential(nn.SiLU(), nn.Linear(dim, n * dim, bias=True))
nn.init.constant_(self.proj[-1].weight, 0)
nn.init.constant_(self.proj[-1].bias, 0)

def forward(self, y):
return [m.unsqueeze(1) for m in self.proj(y).chunk(self.n, dim=1)]

class ModulatedLayerNorm(nn.LayerNorm):
def __init__(self, dim, **kwargs):
super().__init__(dim, **kwargs)
self.modulation = Modulation(dim, 2)
def forward(self, x, y):
scale, shift = self.modulation(y)
return super().forward(x) * (1 + scale) + shift

class DiTBlock(nn.Module):
def __init__(self, head_dim, num_heads, mlp_ratio=4.0):
super().__init__()
dim = head_dim * num_heads
mlp_hidden_dim = int(dim * mlp_ratio)
self.norm1 = ModulatedLayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.attn = Attention(head_dim, num_heads=num_heads, qkv_bias=True)
self.norm2 = ModulatedLayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, dim, bias=True),
)
self.scale_modulation = Modulation(dim, 2)

def forward(self, x, y):
# (B, N, D), (B, D) -> (B, N, D)
# N = H * W / patch_size**2, D = num_heads * head_dim
gate_msa, gate_mlp = self.scale_modulation(y)
x = x + gate_msa * self.attn(self.norm1(x, y))
x = x + gate_mlp * self.mlp(self.norm2(x, y))
return x

def get_pos_embed(in_dim, patch_size, dim, N=10000):
n = in_dim // patch_size # Number of patches per side
assert dim % 4 == 0, 'Embedding dimension must be multiple of 4!'
omega = 1/N**np.linspace(0, 1, dim // 4, endpoint=False) # [dim/4]
freqs = np.outer(np.arange(n), omega) # [n, dim/4]
embeds = repeat(np.stack([np.sin(freqs), np.cos(freqs)]),
' b n d -> b n k d', k=n) # [2, n, n, dim/4]
embeds_2d = np.concatenate([
rearrange(embeds, 'b n k d -> (k n) (b d)'), # [n*n, dim/2]
rearrange(embeds, 'b n k d -> (n k) (b d)'), # [n*n, dim/2]
], axis=1) # [n*n, dim]
return nn.Parameter(torch.tensor(embeds_2d).float().unsqueeze(0), # [1, n*n, dim]
requires_grad=False)

class DiT(nn.Module, ModelMixin):
def __init__(self, in_dim=32, channels=3, patch_size=2, depth=12,
head_dim=64, num_heads=6, mlp_ratio=4.0,
sig_embed_class=None, sig_embed_factor=0.5,
cond_embed_class=None, cond_dropout_prob=0.1, cond_num_classes=None):
super().__init__()
self.in_dim = in_dim
self.channels = channels
self.patch_size = patch_size
self.input_dims = (channels, in_dim, in_dim)

dim = head_dim * num_heads

self.pos_embed = get_pos_embed(in_dim, patch_size, dim)
self.x_embed = PatchEmbed(patch_size, channels, dim, bias=True)
self.sig_embed = (sig_embed_class or SigmaEmbedderSinCos)(
dim, scaling_factor=sig_embed_factor
)
self.conditional = cond_embed_class is not None
if self.conditional:
self.cond_embed = cond_embed_class(
dim, num_classes=cond_num_classes, dropout_prob=cond_dropout_prob
)

self.blocks = nn.ModuleList([
DiTBlock(head_dim, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
])

self.final_norm = ModulatedLayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.final_linear = nn.Linear(dim, patch_size**2 * channels)
self.init()

def init(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)

# Initialize sigma embedding MLP:
nn.init.normal_(self.sig_embed.mlp[0].weight, std=0.02)
nn.init.normal_(self.sig_embed.mlp[2].weight, std=0.02)

# Zero-out output layers:
nn.init.constant_(self.final_linear.weight, 0)
nn.init.constant_(self.final_linear.bias, 0)

def unpatchify(self, x):
# (B, N, patchsize**2 * channels) -> (B, channels, H, W)
patches = self.in_dim // self.patch_size
return rearrange(x, 'b (ph pw) (psh psw c) -> b c (ph psh) (pw psw)',
ph=patches, pw=patches,
psh=self.patch_size, psw=self.patch_size)

def forward(self, x, sigma, cond=None):
# x: (B, C, H, W), sigma: Union[(B, 1, 1, 1), ()], cond: (B, *)
# returns: (B, C, H, W)
# N = num_patches, D = dim = head_dim * num_heads
x = self.x_embed(x) + self.pos_embed # (B, N, D)
y = self.sig_embed(x.shape[0], sigma.squeeze()) # (B, D)
if self.conditional:
assert x.shape[0] == cond.shape[0], 'Conditioning must have same batches as x!'
y += self.cond_embed(cond) # (B, D)
for block in self.blocks:
x = block(x, y) # (B, N, D)
x = self.final_linear(self.final_norm(x, y)) # (B, N, patchsize**2 * channels)
return self.unpatchify(x)

# Embedding table for conditioning on labels assumed to be in [0, num_classes),
# unconditional label encoded as: num_classes
class CondEmbedderLabel(nn.Module):
def __init__(self, hidden_size, num_classes, dropout_prob):
def __init__(self, hidden_size, num_classes, dropout_prob=0.1):
super().__init__()
self.embeddings = nn.Embedding(num_classes + 1, hidden_size)
self.null_cond = num_classes
143 changes: 143 additions & 0 deletions src/smalldiffusion/model_dit.py
@@ -0,0 +1,143 @@
import torch
import numpy as np
from torch import nn
from einops import rearrange, repeat
from .model import ModelMixin, Attention, SigmaEmbedderSinCos, CondSequential

## Diffusion transformer

class PatchEmbed(nn.Module):
def __init__(self, patch_size=16, channels=3, embed_dim=768, bias=True):
super().__init__()
self.proj = nn.Conv2d(channels, embed_dim, stride=patch_size, kernel_size=patch_size, bias=bias)
self.init()

def init(self): # Init like nn.Linear
w = self.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.proj.bias, 0)

def forward(self, x):
return rearrange(self.proj(x), 'b c h w -> b (h w) c')

class Modulation(nn.Module):
def __init__(self, dim, n):
super().__init__()
self.n = n
self.proj = nn.Sequential(nn.SiLU(), nn.Linear(dim, n * dim, bias=True))
nn.init.constant_(self.proj[-1].weight, 0)
nn.init.constant_(self.proj[-1].bias, 0)

def forward(self, y):
return [m.unsqueeze(1) for m in self.proj(y).chunk(self.n, dim=1)]

class ModulatedLayerNorm(nn.LayerNorm):
def __init__(self, dim, **kwargs):
super().__init__(dim, **kwargs)
self.modulation = Modulation(dim, 2)
def forward(self, x, y):
scale, shift = self.modulation(y)
return super().forward(x) * (1 + scale) + shift

class DiTBlock(nn.Module):
def __init__(self, head_dim, num_heads, mlp_ratio=4.0):
super().__init__()
dim = head_dim * num_heads
mlp_hidden_dim = int(dim * mlp_ratio)
self.norm1 = ModulatedLayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.attn = Attention(head_dim, num_heads=num_heads, qkv_bias=True)
self.norm2 = ModulatedLayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_hidden_dim, bias=True),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, dim, bias=True),
)
self.scale_modulation = Modulation(dim, 2)

def forward(self, x, y):
# (B, N, D), (B, D) -> (B, N, D)
# N = H * W / patch_size**2, D = num_heads * head_dim
gate_msa, gate_mlp = self.scale_modulation(y)
x = x + gate_msa * self.attn(self.norm1(x, y))
x = x + gate_mlp * self.mlp(self.norm2(x, y))
return x

def get_pos_embed(in_dim, patch_size, dim, N=10000):
n = in_dim // patch_size # Number of patches per side
assert dim % 4 == 0, 'Embedding dimension must be multiple of 4!'
omega = 1/N**np.linspace(0, 1, dim // 4, endpoint=False) # [dim/4]
freqs = np.outer(np.arange(n), omega) # [n, dim/4]
embeds = repeat(np.stack([np.sin(freqs), np.cos(freqs)]),
' b n d -> b n k d', k=n) # [2, n, n, dim/4]
embeds_2d = np.concatenate([
rearrange(embeds, 'b n k d -> (k n) (b d)'), # [n*n, dim/2]
rearrange(embeds, 'b n k d -> (n k) (b d)'), # [n*n, dim/2]
], axis=1) # [n*n, dim]
return nn.Parameter(torch.tensor(embeds_2d).float().unsqueeze(0), # [1, n*n, dim]
requires_grad=False)

class DiT(nn.Module, ModelMixin):
def __init__(self, in_dim=32, channels=3, patch_size=2, depth=12,
head_dim=64, num_heads=6, mlp_ratio=4.0,
sig_embed_class=None, sig_embed_factor=0.5,
cond_embed=None):
super().__init__()
self.in_dim = in_dim
self.channels = channels
self.patch_size = patch_size
self.input_dims = (channels, in_dim, in_dim)

dim = head_dim * num_heads

self.pos_embed = get_pos_embed(in_dim, patch_size, dim)
self.x_embed = PatchEmbed(patch_size, channels, dim, bias=True)
self.sig_embed = (sig_embed_class or SigmaEmbedderSinCos)(
dim, scaling_factor=sig_embed_factor
)
self.cond_embed = cond_embed

self.blocks = CondSequential(*[
DiTBlock(head_dim, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
])

self.final_norm = ModulatedLayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.final_linear = nn.Linear(dim, patch_size**2 * channels)
self.init()

def init(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)

# Initialize sigma embedding MLP:
nn.init.normal_(self.sig_embed.mlp[0].weight, std=0.02)
nn.init.normal_(self.sig_embed.mlp[2].weight, std=0.02)

# Zero-out output layers:
nn.init.constant_(self.final_linear.weight, 0)
nn.init.constant_(self.final_linear.bias, 0)

def unpatchify(self, x):
# (B, N, patchsize**2 * channels) -> (B, channels, H, W)
patches = self.in_dim // self.patch_size
return rearrange(x, 'b (ph pw) (psh psw c) -> b c (ph psh) (pw psw)',
ph=patches, pw=patches,
psh=self.patch_size, psw=self.patch_size)

def forward(self, x, sigma, cond=None):
# x: (B, C, H, W), sigma: Union[(B, 1, 1, 1), ()], cond: (B, *)
# returns: (B, C, H, W)
# N = num_patches, D = dim = head_dim * num_heads
x = self.x_embed(x) + self.pos_embed # (B, N, D)
y = self.sig_embed(x.shape[0], sigma.squeeze()) # (B, D)
if self.cond_embed is not None:
assert cond is not None and x.shape[0] == cond.shape[0], \
'Conditioning must have same batches as x!'
y += self.cond_embed(cond) # (B, D)
x = self.blocks(x, y) # (B, N, D)
x = self.final_linear(self.final_norm(x, y)) # (B, N, patchsize**2 * channels)
return self.unpatchify(x)
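
With this refactor, conditioning is supplied as a constructed embedder module rather than a class plus keyword arguments. A minimal sketch of building and calling the conditional DiT, reusing the hyperparameters from the updated test (shapes and labels are illustrative):

    import torch
    from smalldiffusion import DiT, CondEmbedderLabel

    dim = 32 * 6   # head_dim * num_heads
    model = DiT(in_dim=16, channels=3, patch_size=2, depth=4,
                head_dim=32, num_heads=6,
                cond_embed=CondEmbedderLabel(dim, num_classes=10))

    x = torch.randn(10, 3, 16, 16)                      # (B, C, H, W)
    sigma = torch.tensor(1.0)                           # scalar sigma is broadcast over the batch
    labels = torch.tensor([1, 2, 3, 4, 5] + [10] * 5)   # 10 = num_classes encodes the null condition
    out = model(x, sigma, cond=labels)                  # (B, C, H, W)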
2 changes: 1 addition & 1 deletion tests/test_diffusion.py
@@ -218,7 +218,7 @@ def test_uncond(self):
def test_cond(self):
for modifier in self.modifiers:
model = modifier(DiT)(in_dim=16, channels=3, patch_size=2, depth=4, head_dim=32, num_heads=6,
cond_embed_class=CondEmbedderLabel, cond_num_classes=10)
cond_embed=CondEmbedderLabel(32*6, 10))
x = torch.randn(10, 3, 16, 16)
sigma = torch.tensor(1)
labels = torch.tensor([1,2,3,4,5] + [10]*5)
