Commit b77a4cc
Updated readme and toy example
yuanchenyang committed Feb 28, 2024
1 parent 62aa121 commit b77a4cc
Showing 2 changed files with 28 additions and 14 deletions.
README.md: 23 changes (16 additions, 7 deletions)
@@ -1,13 +1,16 @@
# smalldiffusion

[![Tutorial blog post][blog-img]][blog-url]
[![Paper link][arxiv-img]][arxiv-url]
[![Pypi project][pypi-img]][pypi-url]
[![Build Status][build-img]][build-url]

A lightweight diffusion library for training and sampling from diffusion
models. It is built for easy experimentation when training new models and
developing new samplers, supporting everything from minimal toy models to state-of-the-art
pretrained models. The [core of this library][diffusion-py] for diffusion
training and sampling is implemented in less than 100 lines of very readable
-pytorch code. To install from [pypi](https://pypi.org/project/smalldiffusion/)::
+pytorch code. To install from [pypi][pypi-url]:

```
pip install smalldiffusion
```
@@ -25,10 +28,10 @@ from smalldiffusion import Swissroll, TimeInputMLP, ScheduleLogLinear, training_
```
dataset = Swissroll(np.pi/2, 5*np.pi, 100)
loader = DataLoader(dataset, batch_size=2048)
model = TimeInputMLP(hidden_dims=(16,128,128,128,128,16))
-schedule = ScheduleLogLinear(N=200)
-trainer = training_loop(loader, model, schedule, epochs=10000)
+schedule = ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)
+trainer = training_loop(loader, model, schedule, epochs=15000)
losses = [ns.loss.item() for ns in trainer]
-*xt, x0 = samples(model, schedule.sample_sigmas(20))
+*xt, x0 = samples(model, schedule.sample_sigmas(20), gam=2)
```
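As its name and parameters suggest, `ScheduleLogLinear(N, sigma_min, sigma_max)` appears to space its `N` noise levels evenly in log-space. A minimal stand-in (a sketch based on that reading, not the library's source):

```python
import numpy as np

# Log-linearly spaced noise levels: a constant ratio between consecutive
# sigmas, so small noise levels get as much resolution as large ones.
def loglinear_sigmas(N, sigma_min=0.005, sigma_max=10):
    return np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), N))

sigmas = loglinear_sigmas(200)
ratios = sigmas[1:] / sigmas[:-1]  # constant, by construction
```

Log-linear spacing is a common choice for schedules covering several orders of magnitude of noise, as the `sigma_min=0.005, sigma_max=10` range above does.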

Results on various toy datasets:
@@ -87,7 +90,7 @@ few examples on tweaking the parameter `gam`:
The core of smalldiffusion depends on the interaction between `data`, `model`
and `schedule` objects. Here we give a specification of these objects. For a
detailed introduction to diffusion models and the notation used in the code, see
-the [accompanying tutorial](https://www.chenyang.co/diffusion.html).
+the [accompanying tutorial][blog-url].
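The interaction between these objects follows the standard denoising objective: draw a noise level from the schedule, corrupt a batch of data, and train the model to predict the added noise. One such step can be sketched roughly as follows (an assumption about what `training_loop` does internally, written in numpy rather than the library's pytorch; `training_step_sketch` and `predict_eps` are hypothetical names):

```python
import numpy as np

def training_step_sketch(x0, sigmas, predict_eps, rng):
    """One denoising training step: corrupt a batch at random noise
    levels and score the model's noise estimate with MSE."""
    sig = rng.choice(sigmas, size=(x0.shape[0], 1))  # one level per sample
    eps = rng.standard_normal(x0.shape)              # the noise we add
    xt = x0 + sig * eps                              # corrupted batch
    eps_hat = predict_eps(xt, sig)                   # model's noise estimate
    return np.mean((eps_hat - eps) ** 2)

rng = np.random.default_rng(0)
x0 = rng.standard_normal((2000, 2))        # stand-in for a dataset batch
sigmas = np.geomspace(0.005, 10, 200)
# sanity check: a model that always predicts zero noise has loss ~ E[eps^2] = 1
loss = training_step_sketch(x0, sigmas, lambda xt, sig: np.zeros_like(xt), rng)
```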

### Data
For training diffusion models, smalldiffusion supports pytorch [`Datasets` and
@@ -164,11 +167,11 @@ object. The generator will yield a sequence of `xt`s produced during
sampling. The sampling loop generalizes most commonly-used samplers:
- For DDPM [[Ho et al.]](https://arxiv.org/abs/2006.11239), use `gam=1, mu=0.5`.
- For DDIM [[Song et al.]](https://arxiv.org/abs/2010.02502), use `gam=1, mu=0`.
-- For accelerated sampling [[Permenter and Yuan]](https://arxiv.org/abs/2306.04848), use `gam=2`.
+- For accelerated sampling [[Permenter and Yuan]][arxiv-url], use `gam=2`.

For more details on how these sampling algorithms can be simplified, generalized
and implemented in only 5 lines of code, see Appendix A of [[Permenter and
-Yuan]](https://arxiv.org/abs/2306.04848).
+Yuan]][arxiv-url].
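The generalized loop can be sketched as follows (a simplified numpy reading of the update in the paper's Appendix A, not the library's exact source; `samples_sketch` and `predict_eps` are hypothetical names standing in for `samples` and the trained model):

```python
import numpy as np

def samples_sketch(predict_eps, sigmas, gam=1.0, mu=0.0, shape=(4, 2), seed=0):
    """Generalized sampler: gam extrapolates between consecutive noise
    estimates, mu controls how much fresh noise is re-injected."""
    rng = np.random.default_rng(seed)
    xt = rng.standard_normal(shape) * sigmas[0]  # start from pure noise
    eps_prev = None
    for sig, sig_next in zip(sigmas[:-1], sigmas[1:]):
        eps = predict_eps(xt, sig)
        # gam=1 keeps the current estimate; gam=2 extrapolates past it
        eps_av = eps if eps_prev is None else gam * eps + (1 - gam) * eps_prev
        eps_prev = eps
        # mu=0: fully deterministic step (DDIM-like); mu=0.5: DDPM-like noise
        sig_p = sig_next ** (1 / (1 - mu)) / sig ** (mu / (1 - mu))
        eta = np.sqrt(max(sig_next**2 - sig_p**2, 0.0))
        xt = xt - (sig - sig_p) * eps_av + eta * rng.standard_normal(shape)
        yield xt

# toy model: if the data sits at the origin, the optimal noise estimate is x/sigma,
# so sampling should contract toward the origin
sigmas = np.geomspace(10, 0.005, 21)
*_, x0 = samples_sketch(lambda x, sig: x / sig, sigmas, gam=2)
```

With `mu=0` the injected-noise term vanishes (`sig_p == sig_next`), recovering a deterministic update; the `gam` extrapolation reuses the previous noise estimate in the same spirit as momentum.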


[diffusion-py]: https://github.com/yuanchenyang/smalldiffusion/blob/main/src/smalldiffusion/diffusion.py
@@ -177,3 +180,9 @@ Yuan]](https://arxiv.org/abs/2306.04848).
[stablediffusion]: https://github.com/yuanchenyang/smalldiffusion/blob/main/examples/stablediffusion.py
[build-img]: https://github.com/yuanchenyang/smalldiffusion/workflows/CI/badge.svg
[build-url]: https://github.com/yuanchenyang/smalldiffusion/actions?query=workflow%3ACI
[pypi-img]: https://img.shields.io/badge/pypi-blue
[pypi-url]: https://pypi.org/project/smalldiffusion/
[blog-img]: https://img.shields.io/badge/Tutorial-blogpost-blue
[blog-url]: https://www.chenyang.co/diffusion.html
[arxiv-img]: https://img.shields.io/badge/Paper-arxiv-blue
[arxiv-url]: https://arxiv.org/abs/2306.04848