Skip to content

Commit

Permalink
Updated examples
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanchenyang committed Oct 16, 2024
1 parent c81360c commit e0a84b3
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 318 deletions.
40 changes: 40 additions & 0 deletions examples/cifar_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.utils import make_grid, save_image
from torch_ema import ExponentialMovingAverage as EMA
from tqdm import tqdm

from smalldiffusion import (
Unet, Scaled, ScheduleLogLinear, ScheduleSigmoid, samples, training_loop,
MappedDataset, img_train_transform, img_normalize
)

def main(train_batch_size=256, epochs=1000, sample_batch_size=64):
    """Train an EMA-smoothed scaled Unet diffusion model on CIFAR-10,
    then generate and save a grid of samples plus a model checkpoint.

    Args:
        train_batch_size: minibatch size used during training.
        epochs: number of passes over the training set.
        sample_batch_size: number of images generated after training.
    """
    # --- Setup ---
    accel = Accelerator()
    # CIFAR10 yields (image, label) pairs; only the image is needed here.
    train_set = MappedDataset(
        CIFAR10('datasets', train=True, download=True,
                transform=img_train_transform),
        lambda pair: pair[0])
    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True)
    train_schedule = ScheduleSigmoid(N=1000)
    model = Scaled(Unet)(32, 3, 3, ch=128, ch_mult=(1, 2, 2, 2), attn_resolutions=(16,))

    # --- Train ---
    ema = EMA(model.parameters(), decay=0.9999)
    ema.to(accel.device)
    for state in training_loop(train_loader, model, train_schedule,
                               epochs=epochs, lr=2e-4, accelerator=accel):
        state.pbar.set_description(f'Loss={state.loss.item():.5}')
        ema.update()

    # --- Sample ---
    # Sampling uses its own (log-linear) noise schedule, independent of the
    # sigmoid schedule used for training.
    sample_schedule = ScheduleLogLinear(sigma_min=0.01, sigma_max=35, N=1000)
    with ema.average_parameters():
        # Only the final denoised batch x0 is kept; intermediates are discarded.
        *_, x0 = samples(model, sample_schedule.sample_sigmas(10), gam=2.1,
                         batchsize=sample_batch_size, accelerator=accel)
        save_image(img_normalize(make_grid(x0)), 'samples.png')
        torch.save(model.state_dict(), 'checkpoint.pth')

if __name__=='__main__':
    main()
54 changes: 28 additions & 26 deletions examples/fashion_mnist_dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,34 @@
from torch_ema import ExponentialMovingAverage as EMA
from tqdm import tqdm

from smalldiffusion import ScheduleDDPM, samples, training_loop, MappedDataset, DiT
from smalldiffusion import (
ScheduleDDPM, samples, training_loop, MappedDataset, DiT,
img_train_transform, img_normalize
)

# Setup
accelerator = Accelerator()
dataset = MappedDataset(FashionMNIST('datasets', train=True, download=True,
transform=tf.Compose([
tf.RandomHorizontalFlip(),
tf.ToTensor(),
tf.Lambda(lambda t: (t * 2) - 1)
])),
lambda x: x[0])
loader = DataLoader(dataset, batch_size=1024, shuffle=True)
schedule = ScheduleDDPM(beta_start=0.0001, beta_end=0.02, N=1000)
model = DiT(in_dim=28, channels=1,
patch_size=2, depth=6, head_dim=32, num_heads=6, mlp_ratio=4.0)
def main(train_batch_size=1024, epochs=300, sample_batch_size=64):
    """Train an EMA-smoothed DiT diffusion model on FashionMNIST, then
    generate and save a grid of samples plus a model checkpoint.

    Args:
        train_batch_size: minibatch size used during training.
        epochs: number of passes over the training set.
        sample_batch_size: number of images generated after training.
    """
    # Setup
    a = Accelerator()
    # FashionMNIST yields (image, label) pairs; only the image is needed here.
    dataset = MappedDataset(FashionMNIST('datasets', train=True, download=True,
                                         transform=img_train_transform),
                            lambda x: x[0])
    loader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True)
    schedule = ScheduleDDPM(beta_start=0.0001, beta_end=0.02, N=1000)
    model = DiT(in_dim=28, channels=1, patch_size=2, depth=6, head_dim=32,
                num_heads=6, mlp_ratio=4.0)

    # Train
    ema = EMA(model.parameters(), decay=0.99)
    ema.to(a.device)
    for ns in training_loop(loader, model, schedule, epochs=epochs, lr=1e-3,
                            accelerator=a):
        ns.pbar.set_description(f'Loss={ns.loss.item():.5}')
        ema.update()

    # Sample (under the EMA-averaged weights); keep only the final batch x0.
    with ema.average_parameters():
        *xt, x0 = samples(model, schedule.sample_sigmas(20), gam=1.6,
                          batchsize=sample_batch_size, accelerator=a)
        save_image(img_normalize(make_grid(x0)), 'samples.png')
        torch.save(model.state_dict(), 'checkpoint.pth')

if __name__=='__main__':
    main()
61 changes: 31 additions & 30 deletions examples/fashion_mnist_dit_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,37 @@
from torch_ema import ExponentialMovingAverage as EMA
from tqdm import tqdm

from smalldiffusion import ScheduleDDPM, samples, training_loop, MappedDataset, DiT, CondEmbedderLabel
from smalldiffusion import (
ScheduleDDPM, samples, training_loop, MappedDataset, DiT, CondEmbedderLabel,
img_train_transform, img_normalize
)

# Setup
accelerator = Accelerator()
dataset = FashionMNIST('datasets', train=True, download=True,
transform=tf.Compose([
tf.RandomHorizontalFlip(),
tf.ToTensor(),
tf.Lambda(lambda t: (t * 2) - 1)
]))
loader = DataLoader(dataset, batch_size=1024, shuffle=True)
schedule = ScheduleDDPM(beta_start=0.0001, beta_end=0.02, N=1000)
cond_embed = CondEmbedderLabel(32*6, 10, 0.1)
model = DiT(in_dim=28, channels=1,
patch_size=2, depth=6, head_dim=32, num_heads=6, mlp_ratio=4.0,
cond_embed=cond_embed)
def main(train_batch_size=1024, epochs=300, sample_batch_size=40):
    """Train an EMA-smoothed, class-conditional DiT diffusion model on
    FashionMNIST, then generate a label-conditioned sample grid and save
    a model checkpoint.

    Args:
        train_batch_size: minibatch size used during training.
        epochs: number of passes over the training set.
        sample_batch_size: number of images generated after training.
    """
    # Setup — keep (image, label) pairs: labels are the conditioning signal.
    a = Accelerator()
    dataset = FashionMNIST('datasets', train=True, download=True,
                           transform=img_train_transform)
    loader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True)
    schedule = ScheduleDDPM(beta_start=0.0001, beta_end=0.02, N=1000)
    # CondEmbedderLabel(32*6, 10, 0.1): embedding dim matches the DiT width
    # (head_dim * num_heads), 10 FashionMNIST classes, 0.1 label-dropout.
    model = DiT(in_dim=28, channels=1,
                patch_size=2, depth=6, head_dim=32, num_heads=6, mlp_ratio=4.0,
                cond_embed=CondEmbedderLabel(32*6, 10, 0.1))

    # Train
    ema = EMA(model.parameters(), decay=0.99)
    ema.to(a.device)
    for ns in training_loop(loader, model, schedule, epochs=epochs, lr=1e-3,
                            conditional=True, accelerator=a):
        ns.pbar.set_description(f'Loss={ns.loss.item():.5}')
        ema.update()

    # Sample (under the EMA-averaged weights), cycling through all 10 class
    # labels so the grid shows every class; keep only the final batch x0.
    with ema.average_parameters():
        *xt, x0 = samples(model, schedule.sample_sigmas(20), gam=1.6,
                          batchsize=sample_batch_size,
                          cond=torch.tensor([i%10 for i in range(sample_batch_size)]),
                          accelerator=a)
        save_image(img_normalize(make_grid(x0)), 'samples.png')
        torch.save(model.state_dict(), 'checkpoint.pth')

if __name__=='__main__':
    main()
54 changes: 28 additions & 26 deletions examples/fashion_mnist_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,34 @@
from torch_ema import ExponentialMovingAverage as EMA
from tqdm import tqdm

from smalldiffusion import ScheduleLogLinear, samples, training_loop, MappedDataset
from unet import Unet
from smalldiffusion import (
ScheduleLogLinear, samples, training_loop, MappedDataset, Unet, Scaled,
img_train_transform, img_normalize
)

# Setup
accelerator = Accelerator()
dataset = MappedDataset(FashionMNIST('datasets', train=True, download=True,
transform=tf.Compose([
tf.RandomHorizontalFlip(),
tf.ToTensor(),
tf.Lambda(lambda t: (t * 2) - 1)
])),
lambda x: x[0])
loader = DataLoader(dataset, batch_size=1024, shuffle=True)
schedule = ScheduleLogLinear(sigma_min=0.02, sigma_max=20, N=800)
model = Unet(dim=28, channels=1, dim_mults=(1,2,4,))
def main(train_batch_size=1024, epochs=300, sample_batch_size=64):
    """Train an EMA-smoothed scaled Unet diffusion model on FashionMNIST,
    then generate and save a grid of samples plus a model checkpoint.

    Args:
        train_batch_size: minibatch size used during training.
        epochs: number of passes over the training set.
        sample_batch_size: number of images generated after training.
    """
    # Setup
    a = Accelerator()
    # FashionMNIST yields (image, label) pairs; only the image is needed here.
    dataset = MappedDataset(FashionMNIST('datasets', train=True, download=True,
                                         transform=img_train_transform),
                            lambda x: x[0])
    loader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True)
    schedule = ScheduleLogLinear(sigma_min=0.01, sigma_max=20, N=800)
    model = Scaled(Unet)(28, 1, 1, ch=64, ch_mult=(1, 1, 2), attn_resolutions=(14,))

    # Train
    ema = EMA(model.parameters(), decay=0.999)
    ema.to(a.device)
    for ns in training_loop(loader, model, schedule, epochs=epochs, lr=7e-4,
                            accelerator=a):
        ns.pbar.set_description(f'Loss={ns.loss.item():.5}')
        ema.update()

    # Sample (under the EMA-averaged weights); keep only the final batch x0.
    with ema.average_parameters():
        *xt, x0 = samples(model, schedule.sample_sigmas(20), gam=1.6,
                          batchsize=sample_batch_size, accelerator=a)
        save_image(img_normalize(make_grid(x0)), 'samples.png')
        torch.save(model.state_dict(), 'checkpoint.pth')

if __name__=='__main__':
    main()
Loading

0 comments on commit e0a84b3

Please sign in to comment.