
Commit

Add cond to unet, custom log scale
yuanchenyang committed Nov 19, 2024
1 parent 2721d84 commit 1ae78a5
Showing 2 changed files with 31 additions and 22 deletions.
20 changes: 11 additions & 9 deletions src/smalldiffusion/model.py
@@ -21,39 +21,41 @@ def predict_eps(self, x, sigma, cond=None):
         return self(x, sigma, cond=cond)

     def predict_eps_cfg(self, x, sigma, cond, cfg_scale):
-        if cond is None:
-            return self.predict_eps(x, sigma)
+        if cond is None or cfg_scale == 0:
+            return self.predict_eps(x, sigma, cond=cond)
         assert sigma.shape == tuple(), 'CFG sampling only supports singleton sigma!'
         uncond = torch.full_like(cond, self.cond_embed.null_cond) # (B,)
         eps_cond, eps_uncond = self.predict_eps( # (B,), (B,)
             torch.cat([x, x]), sigma, torch.cat([cond, uncond]) # (2B,)
         ).chunk(2)
         return eps_cond + cfg_scale * (eps_cond - eps_uncond)

-def sigma_log_scale(batches, sigma, scaling_factor):
+def get_sigma_embeds(batches, sigma, scaling_factor=0.5, log_scale=True):
     if sigma.shape == torch.Size([]):
         sigma = sigma.unsqueeze(0).repeat(batches)
     else:
         assert sigma.shape == (batches,), 'sigma.shape == [] or [batches]!'
-    return torch.log(sigma)*scaling_factor
-
-def get_sigma_embeds(batches, sigma, scaling_factor=0.5):
-    s = sigma_log_scale(batches, sigma, scaling_factor).unsqueeze(1)
+    if log_scale:
+        sigma = torch.log(sigma)
+    s = sigma.unsqueeze(1) * scaling_factor
     return torch.cat([torch.sin(s), torch.cos(s)], dim=1)

 # A simple embedding that works just as well as usual sinusoidal embedding
 class SigmaEmbedderSinCos(nn.Module):
-    def __init__(self, hidden_size, scaling_factor=0.5):
+    def __init__(self, hidden_size, scaling_factor=0.5, log_scale=True):
         super().__init__()
         self.scaling_factor = scaling_factor
+        self.log_scale = log_scale
         self.mlp = nn.Sequential(
             nn.Linear(2, hidden_size, bias=True),
             nn.SiLU(),
             nn.Linear(hidden_size, hidden_size, bias=True),
         )

     def forward(self, batches, sigma):
-        sig_embed = get_sigma_embeds(batches, sigma, self.scaling_factor) # (B, 2)
+        sig_embed = get_sigma_embeds(batches, sigma,
+                                     self.scaling_factor,
+                                     self.log_scale) # (B, 2)
         return self.mlp(sig_embed) # (B, D)


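For reference, a minimal usage sketch of the updated sigma embedding (not part of the commit; it assumes the module is importable as smalldiffusion.model). With the default log_scale=True, sigma is log-transformed before the sin/cos projection, reproducing the old sigma_log_scale behaviour; log_scale=False scales the raw sigma instead.

import torch
from smalldiffusion.model import get_sigma_embeds, SigmaEmbedderSinCos

sigma = torch.tensor(2.5)                              # singleton sigma, broadcast to the batch

emb_log = get_sigma_embeds(4, sigma)                   # (4, 2), log-scaled by default
emb_lin = get_sigma_embeds(4, sigma, log_scale=False)  # (4, 2), raw sigma * scaling_factor

embedder = SigmaEmbedderSinCos(hidden_size=256, scaling_factor=0.5, log_scale=True)
out = embedder(4, sigma)                               # (4, 256)
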
33 changes: 20 additions & 13 deletions src/smalldiffusion/model_unet.py
@@ -7,7 +7,7 @@
 from itertools import pairwise
 from torch import nn
 from .model import (
-    sigma_log_scale, alpha, Attention, ModelMixin, CondSequential, SigmaEmbedderSinCos,
+    alpha, Attention, ModelMixin, CondSequential, SigmaEmbedderSinCos,
 )

 def Normalize(ch):
@@ -85,11 +85,13 @@ class Unet(nn.Module, ModelMixin):
     def __init__(self, in_dim, in_ch, out_ch,
                  ch = 128,
                  ch_mult = (1,2,2,2),
+                 embed_ch_mult = 4,
                  num_res_blocks = 2,
                  attn_resolutions = (16,),
                  dropout = 0.1,
                  resamp_with_conv = True,
                  sig_embed = None,
+                 cond_embed = None,
                  ):
         super().__init__()

@@ -98,17 +100,18 @@ def __init__(self, in_dim, in_ch, out_ch,
         self.num_resolutions = len(ch_mult)
         self.num_res_blocks = num_res_blocks
         self.input_dims = (in_ch, in_dim, in_dim)
-        self.temb_ch = self.ch*4
+        self.temb_ch = self.ch * embed_ch_mult

-        # sigma embedding
+        # Embeddings
         self.sig_embed = sig_embed or SigmaEmbedderSinCos(self.temb_ch)
         make_block = lambda in_ch, out_ch: ResnetBlock(
             in_ch=in_ch, out_ch=out_ch, temb_channels=self.temb_ch, dropout=dropout
         )
+        self.cond_embed = cond_embed

+        # Downsampling
         curr_res = in_dim
         in_ch_dim = [ch * m for m in (1,)+ch_mult]

-        # downsampling
         self.conv_in = torch.nn.Conv2d(in_ch, self.ch, kernel_size=3, stride=1, padding=1)
         self.downs = nn.ModuleList()
         for i, (block_in, block_out) in enumerate(pairwise(in_ch_dim)):
@@ -125,14 +128,14 @@ def __init__(self, in_dim, in_ch, out_ch,
                 curr_res = curr_res // 2
             self.downs.append(down)

-        # middle
+        # Middle
         self.mid = CondSequential(
             make_block(block_in, block_in),
             AttnBlock(block_in),
             make_block(block_in, block_in)
         )

-        # upsampling
+        # Upsampling
         self.ups = nn.ModuleList()
         for i_level, (block_out, next_skip_in) in enumerate(pairwise(reversed(in_ch_dim))):
             up = nn.Module()
@@ -151,7 +154,7 @@ def __init__(self, in_dim, in_ch, out_ch,
                 curr_res = curr_res * 2
             self.ups.append(up)

-        # out
+        # Out
         self.out_layer = nn.Sequential(
             Normalize(block_in),
             nn.SiLU(),
@@ -161,25 +164,29 @@ def __init__(self, in_dim, in_ch, out_ch,
     def forward(self, x, sigma, cond=None):
         assert x.shape[2] == x.shape[3] == self.in_dim

-        # sigma embedding
-        temb = self.sig_embed(x.shape[0], sigma.squeeze())
+        # Embeddings
+        emb = self.sig_embed(x.shape[0], sigma.squeeze())
+        if self.cond_embed is not None:
+            assert cond is not None and x.shape[0] == cond.shape[0], \
+                'Conditioning must have same batches as x!'
+            emb += self.cond_embed(cond)

         # downsampling
         hs = [self.conv_in(x)]
         for down in self.downs:
             for block in down.blocks:
-                h = block(hs[-1], temb)
+                h = block(hs[-1], emb)
                 hs.append(h)
             if hasattr(down, 'downsample'):
                 hs.append(down.downsample(hs[-1]))

         # middle
-        h = self.mid(hs[-1], temb)
+        h = self.mid(hs[-1], emb)

         # upsampling
         for up in self.ups:
             for block in up.blocks:
-                h = block(torch.cat([h, hs.pop()], dim=1), temb)
+                h = block(torch.cat([h, hs.pop()], dim=1), emb)
             if hasattr(up, 'upsample'):
                 h = up.upsample(h)

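For reference, a minimal sketch of how the new conditioning hook might be wired up (not part of the commit). The LabelEmbedder class below is hypothetical: any module that maps cond to vectors of size ch * embed_ch_mult can be passed as cond_embed, and classifier-free guidance via predict_eps_cfg additionally expects it to expose a null_cond index, as used in model.py above.

import torch
from torch import nn
from smalldiffusion.model_unet import Unet

# Hypothetical label embedder with a reserved "null" class index for CFG sampling
class LabelEmbedder(nn.Module):
    def __init__(self, num_classes, dim):
        super().__init__()
        self.null_cond = num_classes                   # index used for the unconditional branch
        self.embed = nn.Embedding(num_classes + 1, dim)
    def forward(self, labels):
        return self.embed(labels)

model = Unet(in_dim=32, in_ch=3, out_ch=3, ch=128, embed_ch_mult=4,
             cond_embed=LabelEmbedder(10, 128 * 4))    # embedding dim must equal temb_ch = ch * embed_ch_mult

x      = torch.randn(8, 3, 32, 32)
sigma  = torch.tensor(1.0)                             # singleton sigma
labels = torch.randint(0, 10, (8,))

eps      = model(x, sigma, cond=labels)                            # conditional prediction, (8, 3, 32, 32)
eps_cfg  = model.predict_eps_cfg(x, sigma, labels, cfg_scale=2.0)  # CFG-weighted prediction
eps_none = model.predict_eps_cfg(x, sigma, labels, cfg_scale=0)    # falls back to the plain conditional path
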
