Skip to content

Commit

Permalink
mint adaption for dit/animatediff
Browse files Browse the repository at this point in the history
  • Loading branch information
HaoyangLee authored and liuchuting committed Dec 26, 2024
1 parent b12cbc1 commit 6efa376
Show file tree
Hide file tree
Showing 44 changed files with 665 additions and 686 deletions.
24 changes: 8 additions & 16 deletions examples/animatediff/ad/models/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import mindspore as ms
import mindspore.nn as nn
from mindspore import ops
from mindspore import mint


class AutoencoderKL(nn.Cell):
Expand All @@ -39,25 +39,17 @@ def __init__(
self.encoder = Encoder(dtype=self.dtype, **ddconfig)
self.decoder = Decoder(dtype=self.dtype, **ddconfig)
assert ddconfig["double_z"]
self.quant_conv = nn.Conv2d(
2 * ddconfig["z_channels"], 2 * embed_dim, 1, pad_mode="valid", has_bias=True
).to_float(self.dtype)
self.post_quant_conv = nn.Conv2d(
embed_dim, ddconfig["z_channels"], 1, pad_mode="valid", has_bias=True
).to_float(self.dtype)
self.quant_conv = mint.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1, bias=True).to_float(self.dtype)
self.post_quant_conv = mint.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1, bias=True).to_float(self.dtype)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
self.register_buffer("colorize", ms.ops.standard_normal(3, colorize_nlabels, 1, 1))
self.register_buffer("colorize", mint.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

self.split = ops.Split(axis=1, output_num=2)
self.exp = ops.Exp()
self.stdnormal = ops.StandardNormal()

def init_from_ckpt(self, path, ignore_keys=list()):
sd = ms.load_checkpoint(path)["state_dict"]
keys = list(sd.keys())
Expand All @@ -77,8 +69,8 @@ def decode(self, z):
def encode(self, x):
    """
    Encode ``x`` and return one random sample of the resulting latent.

    The encoder output is passed through ``quant_conv``; its result is split in half
    along dim 1 into a mean and a log-variance, and the returned value is
    ``mean + std * noise`` with ``noise`` drawn from ``mint.randn``.

    Args:
        x: input passed unchanged to ``self.encoder``.

    Returns:
        The sampled latent, with the same shape as the mean half of the moments.
    """
    h = self.encoder(x)
    moments = self.quant_conv(h)
    # First half of dim 1 is the mean, second half the log-variance.
    mean, logvar = mint.split(moments, moments.shape[1] // 2, dim=1)
    # Bound the log-variance before exponentiating it.
    logvar = mint.clamp(logvar, -30.0, 20.0)
    std = mint.exp(0.5 * logvar)
    # Noise is drawn exactly once; the legacy ops-based duplicate of this computation
    # (whose results were immediately overwritten) has been removed.
    x = mean + std * mint.randn(*mean.shape)
    return x
64 changes: 39 additions & 25 deletions examples/animatediff/ad/models/diffusion/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,35 @@
from ad.modules.diffusionmodules.util import make_beta_schedule

import mindspore as ms
from mindspore import Parameter, Tensor
from mindspore import Parameter, Tensor, _no_grad
from mindspore import dtype as mstype
from mindspore import nn, ops
from mindspore import jit_class, mint, nn, ops

from mindone.utils.config import instantiate_from_config
from mindone.utils.misc import default, exists, extract_into_tensor

_logger = logging.getLogger(__name__)


@jit_class
class no_grad(_no_grad):
    """
    Mode-aware wrapper around MindSpore's ``_no_grad`` context manager.

    The execution mode is read once, when the object is constructed. Entering and
    leaving the context is forwarded to ``_no_grad`` only if that mode was PyNative;
    in any other mode both hooks do nothing, so ``with no_grad():`` can be written
    unconditionally.
    """

    def __init__(self):
        super().__init__()
        current_mode = ms.get_context("mode")
        # Cached so that __enter__ and __exit__ always take the same decision.
        self._pynative = current_mode == ms.PYNATIVE_MODE

    def __enter__(self):
        # Delegate to the parent only in PyNative mode; otherwise a no-op.
        if self._pynative:
            super().__enter__()

    def __exit__(self, *args):
        # Mirror __enter__: undo the parent's state only if it was entered.
        if self._pynative:
            super().__exit__(*args)


class DDPM(nn.Cell):
def __init__(
self,
Expand Down Expand Up @@ -98,7 +117,6 @@ def __init__(
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
self.isnan = ops.IsNan()
self.register_schedule(
given_betas=given_betas,
beta_schedule=beta_schedule,
Expand Down Expand Up @@ -230,7 +248,6 @@ def __init__(
self.instantiate_cond_stage(cond_stage_config)

self.clip_denoised = False
self.uniform_int = ops.UniformInt()

self.restarted_from_ckpt = False
if ckpt_path is not None:
Expand Down Expand Up @@ -313,7 +330,7 @@ def get_latents_2d(self, x):
B, C, H, W = x.shape
if C != 3:
# b h w c -> b c h w
x = ops.transpose(x, (0, 3, 1, 2))
x = mint.permute(x, (0, 3, 1, 2))
# raise ValueError("Expect input shape (b 3 h w), but get {}".format(x.shape))

z = ops.stop_gradient(self.scale_factor * self.first_stage_model.encode(x))
Expand All @@ -325,13 +342,13 @@ def get_latents(self, x):
B, F, C, H, W = x.shape
if C != 3:
raise ValueError("Expect input shape (b f 3 h w), but get {}".format(x.shape))
x = ops.reshape(x, (-1, C, H, W))
x = mint.reshape(x, (-1, C, H, W))

z = ops.stop_gradient(self.scale_factor * self.first_stage_model.encode(x))

# (b*f c h w) -> (b f c h w) -> (b c f h w )
z = ops.reshape(z, (B, F, z.shape[1], z.shape[2], z.shape[3]))
z = ops.transpose(z, (0, 2, 1, 3, 4))
z = mint.reshape(z, (B, F, z.shape[1], z.shape[2], z.shape[3]))
z = mint.permute(z, (0, 2, 1, 3, 4))

return z

Expand Down Expand Up @@ -363,20 +380,17 @@ def construct(self, x: ms.Tensor, text_tokens: ms.Tensor, control=None, **kwargs
- assume unet3d input/output shape: (b c f h w)
unet2d input/output shape: (b c h w)
"""

# 1. get image/video latents z using vae
z = self.get_latents(x)

# 2. sample timestep and add noise to latents
t = self.uniform_int(
(x.shape[0],), Tensor(0, dtype=mstype.int32), Tensor(self.num_timesteps, dtype=mstype.int32)
)
noise = ops.randn_like(z)
with no_grad():
# 1. get image/video latents z using vae
z = self.get_latents(x)
# 2. get condition embeddings
cond = self.get_condition_embeddings(text_tokens, control)

# 3. sample timestep and add noise to latents
t = mint.randint(0, self.num_timesteps, (x.shape[0],))
noise = mint.randn_like(z)
noisy_latents, snr = self.add_noise(z, noise, t)

# 3. get condition embeddings
cond = self.get_condition_embeddings(text_tokens, control)

# 4. unet forward, predict conditioned on conditions
model_output = self.apply_model(
noisy_latents,
Expand All @@ -395,11 +409,11 @@ def construct(self, x: ms.Tensor, text_tokens: ms.Tensor, control=None, **kwargs
loss_sample = self.reduce_loss(loss_element)

if self.snr_gamma is not None:
snr_gamma = ops.ones_like(snr) * self.snr_gamma
snr_gamma = mint.ones_like(snr) * self.snr_gamma
# TODO: for v-pred, .../ (snr+1)
# TODO: for beta zero rescale, consider snr=0
# min{snr, gamma} / snr
loss_weight = ops.stack((snr, snr_gamma), axis=0).min(axis=0) / snr
loss_weight = mint.stack((snr, snr_gamma), dim=0).min(axis=0) / snr
loss = (loss_weight * loss_sample).mean()
else:
loss = loss_sample.mean()
Expand Down Expand Up @@ -444,7 +458,7 @@ def get_latents(self, x):
z = ops.stop_gradient(self.scale_factor * x)

# (b f c h w) -> (b c f h w )
z = ops.transpose(z, (0, 2, 1, 3, 4))
z = mint.permute(z, (0, 2, 1, 3, 4))
return z

def get_condition_embeddings(self, text_tokens, control=None):
Expand All @@ -468,13 +482,13 @@ def construct(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, **kwargs)
if self.conditioning_key is None:
out = self.diffusion_model(x, t, **kwargs)
elif self.conditioning_key == "concat":
x_concat = ops.concat((x, c_concat), axis=1)
x_concat = mint.concat((x, c_concat), dim=1)
out = self.diffusion_model(x_concat, t, **kwargs)
elif self.conditioning_key == "crossattn": # t2v task
context = c_crossattn
out = self.diffusion_model(x, t, context=context, **kwargs)
elif self.conditioning_key == "hybrid":
x_concat = ops.concat((x, c_concat), axis=1)
x_concat = mint.concat((x, c_concat), dim=1)
context = c_crossattn
out = self.diffusion_model(x_concat, t, context=context, **kwargs)
elif self.conditioning_key == "crossattn-adm":
Expand Down
Loading

0 comments on commit 6efa376

Please sign in to comment.