Skip to content

Commit

Permalink
update t2v training logic and put train_step as jit
Browse files Browse the repository at this point in the history
  • Loading branch information
wtomin committed Nov 28, 2024
1 parent 1b79586 commit 3af23f9
Show file tree
Hide file tree
Showing 5 changed files with 393 additions and 160 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,6 @@ def get_attention_mask(self, attention_mask):
attention_mask = attention_mask.to(ms.bool_) # use bool for sdpa
return attention_mask

# @ms.jit # use graph mode
def construct(
self,
hidden_states: ms.Tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from opensora.acceleration.communications import prepare_parallel_data
from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info
from opensora.utils.ms_utils import no_grad

import mindspore as ms
from mindspore import _no_grad, mint, nn, ops
from mindspore import nn, ops

from mindone.diffusers.training_utils import compute_snr

Expand All @@ -13,31 +14,11 @@
logger = logging.getLogger(__name__)


@ms.jit_class
class no_grad(_no_grad):
"""
A context manager that suppresses gradient memory allocation in PyNative mode.
"""

def __init__(self):
super().__init__()
self._pynative = ms.get_context("mode") == ms.PYNATIVE_MODE

def __enter__(self):
if self._pynative:
super().__enter__()

def __exit__(self, *args):
if self._pynative:
super().__exit__(*args)


class DiffusionWithLoss(nn.Cell):
"""An training pipeline for diffusion model
Args:
model (nn.Cell): A noise prediction model to denoise the encoded image latents.
vae (nn.Cell): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
noise_scheduler: (object): A class for noise scheduler, such as DDPM scheduler
text_encoder / text_encoder_2 (nn.Cell): A text encoding model which accepts token ids and returns text embeddings in shape (T, D).
T is the number of tokens, and D is the embedding dimension.
Expand All @@ -48,11 +29,9 @@ def __init__(
self,
network: nn.Cell,
noise_scheduler,
vae: nn.Cell = None,
text_encoder: nn.Cell = None,
text_encoder_2: nn.Cell = None, # not to use yet
text_emb_cached: bool = True,
video_emb_cached: bool = False,
use_image_num: int = 0,
dtype=ms.float32,
noise_offset: float = 0.0,
Expand All @@ -61,7 +40,6 @@ def __init__(
super().__init__()
# TODO: is set_grad() necessary?
self.network = network.set_grad()
self.vae = vae
self.noise_scheduler = noise_scheduler
self.prediction_type = self.noise_scheduler.config.prediction_type
self.num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
Expand All @@ -72,7 +50,6 @@ def __init__(
self.dtype = dtype

self.text_emb_cached = text_emb_cached
self.video_emb_cached = video_emb_cached

if self.text_emb_cached:
self.text_encoder = None
Expand Down Expand Up @@ -101,44 +78,6 @@ def get_condition_embeddings(self, text_tokens, encoder_attention_mask):
text_emb = ops.stack(text_emb, axis=1)
return text_emb

def vae_encode(self, x):
image_latents = self.vae.encode(x)
return image_latents

def vae_decode(self, x):
"""
Args:
x: (b c h w), denoised latent
Return:
y: (b H W 3), batch of images, normalized to [0, 1]
"""
# b, c, f, h, w = x.shape
y = self.vae.decode(x)
y = ops.clip_by_value((y + 1.0) / 2.0, clip_value_min=0.0, clip_value_max=1.0)

return y # b c f h w

def get_latents(self, x):
if x.dim() == 5:
B, C, F, H, W = x.shape
if C != 3:
raise ValueError("Expect input shape (b 3 f h w), but get {}".format(x.shape))
if self.use_image_num == 0:
z = self.vae_encode(x) # (b, c, f, h, w)
else:
videos, images = x[:, :, : -self.use_image_num], x[:, :, -self.use_image_num :]
videos = self.vae_encode(videos) # (b, c, f, h, w)
# (b, c, f, h, w) -> (b, f, c, h, w) -> (b*f, c, h, w) -> (b*f, c, 1, h, w)
images = images.permute(0, 2, 1, 3, 4).reshape(-1, C, H, W).unsqueeze(2)
images = self.vae_encode(images) # (b*f, c, 1, h, w)
# (b*f, c, 1, h, w) -> (b*f, c, h, w) -> (b, f, c, h, w) -> (b, c, f, h, w)
_, c, _, h, w = images.shape
images = images.squeeze(2).reshape(B, self.use_image_num, c, h, w).permute(0, 2, 1, 3, 4)
z = mint.cat([videos, images], dim=2) # b c 16+4, h, w
else:
raise ValueError("Incorrect Dimensions of x")
return z

def construct(
self,
x: ms.Tensor,
Expand All @@ -150,7 +89,7 @@ def construct(
Video diffusion model forward and loss computation for training
Args:
x: pixel values of video frames, resized and normalized to shape (b c f+num_img h w)
x: the latent features of video frames (b c t' h' w'), where t' h' w' are the shape of latent features after vae's encoding.
attention_mask: the mask for latent features of shape (b t' h' w'), where t' h' w' are the shape of latent features after vae's encoding.
text_tokens: text tokens padded to fixed shape (B F L) or text embedding of shape (B F L D) if using text embedding cache
encoder_attention_mask: the mask for text tokens/embeddings of a fixed shape (B F L)
Expand All @@ -162,13 +101,10 @@ def construct(
- inputs should matches dataloder output order
- assume model input/output shape: (b c f+num_img h w)
"""
# 1. get image/video latents z using vae

x = x.to(self.dtype)
with no_grad():
if not self.video_emb_cached:
x = ops.stop_gradient(self.get_latents(x))

# 2. get conditions
# get conditions
if not self.text_emb_cached:
text_embed = ops.stop_gradient(self.get_condition_embeddings(text_tokens, encoder_attention_mask))
else:
Expand Down Expand Up @@ -282,7 +218,7 @@ def construct(
Video diffusion model forward and loss computation for training
Args:
x: pixel values of video frames, resized and normalized to shape (b c f+num_img h w)
x: the latent features of video frames (b c t' h' w'), where t' h' w' are the shape of latent features after vae's encoding.
attention_mask: the mask for latent features of shape (b t' h' w'), where t' h' w' are the shape of latent features after vae's encoding.
text_tokens: text tokens padded to fixed shape (B F L) or text embedding of shape (B F L D) if using text embedding cache
encoder_attention_mask: the mask for text tokens/embeddings of a fixed shape (B F L)
Expand All @@ -297,9 +233,6 @@ def construct(
# 1. get image/video latents z using vae
x = x.to(self.dtype)
with no_grad():
if not self.video_emb_cached:
x = ops.stop_gradient(self.get_latents(x))

# 2. get conditions
if not self.text_emb_cached:
text_embed = ops.stop_gradient(self.get_condition_embeddings(text_tokens, encoder_attention_mask))
Expand Down
187 changes: 187 additions & 0 deletions examples/opensora_pku/opensora/train/train_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
"""Train step wrapper supporting setting drop overflow update, ema etc"""

from packaging import version

import mindspore as ms
import mindspore.context as context
from mindspore import Tensor, nn, ops
from mindspore.boost.grad_accumulation import gradient_accumulation_op as _grad_accum_op
from mindspore.boost.grad_accumulation import gradient_clear_op as _grad_clear_op
from mindspore.common import RowTensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P

_grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * F.cast(reciprocal(scale), F.dtype(grad))


@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_row_tensor(scale, grad):
return RowTensor(
grad.indices,
grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
grad.dense_shape,
)


class TrainOneStepWrapper(nn.TrainOneStepWithLossScaleCell):
"""TrainStep with ema and clip grad.
Args:
drop_overflow_update: if True, network will not be updated when gradient is overflow.
scale_sense (Union[Tensor, Cell]): If this value is a Cell, it will be called
to update loss scale. If this value is a Tensor, the loss scale can be modified by `set_sense_scale`,
the shape should be :math:`()` or :math:`(1,)`.
zero_helper (class): Zero redundancy optimizer(ZeRO) build helper, default is None.
Returns:
Tuple of 3 Tensor, the loss, overflow flag and current loss scale value.
loss (Tensor) - A scalar, the loss value.
overflow (Tensor) - A scalar, whether overflow occur or not, the type is bool.
loss scale (Tensor) - The loss scale value, the shape is :math:`()` or :math:`(1,)`.
"""

def __init__(
self,
network,
optimizer,
scale_sense=1.0,
ema=None,
updates=0,
drop_overflow_update=True,
gradient_accumulation_steps=1,
clip_grad=False,
clip_norm=1.0,
verbose=False,
zero_helper=None,
):
super().__init__(network, optimizer, scale_sense)
self.ema = ema
self.drop_overflow_update = drop_overflow_update

assert isinstance(clip_grad, bool), f"Invalid type of clip_grad, got {type(clip_grad)}, expected bool"
assert clip_norm > 0.0 and isinstance(clip_norm, float), f"clip_norm must be float > 1.0, but got {clip_norm}"
self.clip_grad = clip_grad
self.clip_norm = clip_norm

assert gradient_accumulation_steps >= 1
self.accum_steps = gradient_accumulation_steps
if gradient_accumulation_steps > 1:
self.accumulated_grads = optimizer.parameters.clone(prefix="grad_accumulated_", init="zeros")

self.cur_accum_step = ms.Parameter(ms.Tensor(0, dtype=ms.int32), name="accum_step")
self.zero = Tensor(0, ms.int32)

self.verbose = verbose
self.is_cpu_device = context.get_context("device_target") == "CPU" # to support CPU in CI
self.skip_start_overflow_check = version.parse(ms.__version__) >= version.parse("2.1")

self.map = ops.Map()
self.partial = ops.Partial()

# zero init
self.zero_helper = zero_helper
self.zero_stage = zero_helper.zero_stage if zero_helper is not None else 0
self.run_optimizer = zero_helper.run_optimizer if zero_helper is not None else self.optimizer
self.grad_reducer = self.grad_reducer if self.zero_stage == 0 else nn.Identity()
if self.zero_stage != 0:
self.zero_helper.split_params()
if gradient_accumulation_steps > 1:
self.accumulated_grads = optimizer.parameters.clone(prefix="grad_accumulated_", init="zeros")

@ms.jit(jit_config="O1")
def construct(self, *inputs):
# compute loss
weights = self.weights
loss = self.network(*inputs) # mini-batch loss
scaling_sens = self.scale_sense

# check loss overflow. (after ms2.1, it's done together with gradient overflow checking)
if self.skip_start_overflow_check:
status = Tensor([0] * 8, mstype.int32)
else:
if not self.is_cpu_device:
status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
else:
status = None

scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) # loss scale value

# 1. compute gradients (of the up-scaled loss w.r.t. the model weights)
grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)

# Gradient communication
if self.zero_helper is not None:
grads = self.zero_helper.cal_gradients(grads)

if self.accum_steps == 1:
grads = self.grad_reducer(grads)
scaling_sens = ops.depend(scaling_sens, grads)

# 2. down-scale gradients by loss_scale. grads = grads / scaling_sense / grad_accum_steps
# also divide gradients by accumulation steps to avoid taking mean of the accumulated gradients later
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads) # accum_steps division is done later

# 3. check gradient overflow
if not self.is_cpu_device:
cond = self.get_overflow_status(status, grads)
overflow = self.process_loss_scale(cond)
else:
overflow = ms.Tensor(False)
cond = ms.Tensor(False)

# accumulate gradients and update model weights if no overflow or allow to update even when overflow
if (not self.drop_overflow_update) or (not overflow):
# 4. gradient accumulation if enabled
if self.accum_steps > 1:
# self.accumulated_grads += grads / accum_steps
loss = F.depend(
loss, self.hyper_map(F.partial(_grad_accum_op, self.accum_steps), self.accumulated_grads, grads)
)

# self.cur_accum_step += 1
loss = F.depend(loss, ops.assign_add(self.cur_accum_step, Tensor(1, ms.int32)))

if self.cur_accum_step >= self.accum_steps:
# 5. gradient reduction on distributed GPUs/NPUs
grads = self.grad_reducer(self.accumulated_grads)

# 6. clip grad
if self.clip_grad:
grads = ops.clip_by_global_norm(grads, self.clip_norm)
# 7. optimize
loss = F.depend(loss, self.run_optimizer(grads))

# clear gradient accumulation states
loss = F.depend(loss, self.hyper_map(F.partial(_grad_clear_op), self.accumulated_grads))
# self.cur_accum_step = 0
loss = F.depend(loss, ops.assign(self.cur_accum_step, self.zero))
else:
# update LR in each gradient step but not optimize net parameter
# to ensure the LR curve is consistent
# FIXME: for ms>=2.2, get_lr() will not increase global step by 1. we need to do it manually.
loss = F.depend(loss, self.optimizer.get_lr())
else:
# 5. gradient reduction on distributed GPUs/NPUs
# 6. clip grad
if self.clip_grad:
grads = ops.clip_by_global_norm(grads, self.clip_norm)
# 7. optimize
loss = F.depend(loss, self.run_optimizer(grads))

# 8.ema
if self.ema is not None:
self.ema.ema_update()
# else:
# print("WARNING: Gradient overflow! update skipped.") # TODO: recover it after 910B in-graph print issue fixed

return loss, cond, scaling_sens
Loading

0 comments on commit 3af23f9

Please sign in to comment.