diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index 735fa725dd..31edf2913a 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -340,7 +340,6 @@ def get_attention_mask(self, attention_mask): attention_mask = attention_mask.to(ms.bool_) # use bool for sdpa return attention_mask - # @ms.jit # use graph mode def construct( self, hidden_states: ms.Tensor, diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py b/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py index cc4665bff0..bda9bf7ae0 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py @@ -2,9 +2,10 @@ from opensora.acceleration.communications import prepare_parallel_data from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info +from opensora.utils.ms_utils import no_grad import mindspore as ms -from mindspore import _no_grad, mint, nn, ops +from mindspore import nn, ops from mindone.diffusers.training_utils import compute_snr @@ -13,31 +14,11 @@ logger = logging.getLogger(__name__) -@ms.jit_class -class no_grad(_no_grad): - """ - A context manager that suppresses gradient memory allocation in PyNative mode. - """ - - def __init__(self): - super().__init__() - self._pynative = ms.get_context("mode") == ms.PYNATIVE_MODE - - def __enter__(self): - if self._pynative: - super().__enter__() - - def __exit__(self, *args): - if self._pynative: - super().__exit__(*args) - - class DiffusionWithLoss(nn.Cell): """An training pipeline for diffusion model Args: model (nn.Cell): A noise prediction model to denoise the encoded image latents. 
- vae (nn.Cell): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. noise_scheduler: (object): A class for noise scheduler, such as DDPM scheduler text_encoder / text_encoder_2 (nn.Cell): A text encoding model which accepts token ids and returns text embeddings in shape (T, D). T is the number of tokens, and D is the embedding dimension. @@ -48,11 +29,9 @@ def __init__( self, network: nn.Cell, noise_scheduler, - vae: nn.Cell = None, text_encoder: nn.Cell = None, text_encoder_2: nn.Cell = None, # not to use yet text_emb_cached: bool = True, - video_emb_cached: bool = False, use_image_num: int = 0, dtype=ms.float32, noise_offset: float = 0.0, @@ -61,7 +40,6 @@ def __init__( super().__init__() # TODO: is set_grad() necessary? self.network = network.set_grad() - self.vae = vae self.noise_scheduler = noise_scheduler self.prediction_type = self.noise_scheduler.config.prediction_type self.num_train_timesteps = self.noise_scheduler.config.num_train_timesteps @@ -72,7 +50,6 @@ def __init__( self.dtype = dtype self.text_emb_cached = text_emb_cached - self.video_emb_cached = video_emb_cached if self.text_emb_cached: self.text_encoder = None @@ -101,44 +78,6 @@ def get_condition_embeddings(self, text_tokens, encoder_attention_mask): text_emb = ops.stack(text_emb, axis=1) return text_emb - def vae_encode(self, x): - image_latents = self.vae.encode(x) - return image_latents - - def vae_decode(self, x): - """ - Args: - x: (b c h w), denoised latent - Return: - y: (b H W 3), batch of images, normalized to [0, 1] - """ - # b, c, f, h, w = x.shape - y = self.vae.decode(x) - y = ops.clip_by_value((y + 1.0) / 2.0, clip_value_min=0.0, clip_value_max=1.0) - - return y # b c f h w - - def get_latents(self, x): - if x.dim() == 5: - B, C, F, H, W = x.shape - if C != 3: - raise ValueError("Expect input shape (b 3 f h w), but get {}".format(x.shape)) - if self.use_image_num == 0: - z = self.vae_encode(x) # (b, c, f, h, w) - else: - videos, 
images = x[:, :, : -self.use_image_num], x[:, :, -self.use_image_num :] - videos = self.vae_encode(videos) # (b, c, f, h, w) - # (b, c, f, h, w) -> (b, f, c, h, w) -> (b*f, c, h, w) -> (b*f, c, 1, h, w) - images = images.permute(0, 2, 1, 3, 4).reshape(-1, C, H, W).unsqueeze(2) - images = self.vae_encode(images) # (b*f, c, 1, h, w) - # (b*f, c, 1, h, w) -> (b*f, c, h, w) -> (b, f, c, h, w) -> (b, c, f, h, w) - _, c, _, h, w = images.shape - images = images.squeeze(2).reshape(B, self.use_image_num, c, h, w).permute(0, 2, 1, 3, 4) - z = mint.cat([videos, images], dim=2) # b c 16+4, h, w - else: - raise ValueError("Incorrect Dimensions of x") - return z - def construct( self, x: ms.Tensor, @@ -150,7 +89,7 @@ def construct( Video diffusion model forward and loss computation for training Args: - x: pixel values of video frames, resized and normalized to shape (b c f+num_img h w) + x: the latent features of video frames (b c t' h' w'), where t' h' w' are the shape of latent features after vae's encoding. attention_mask: the mask for latent features of shape (b t' h' w'), where t' h' w' are the shape of latent features after vae's encoding. text_tokens: text tokens padded to fixed shape (B F L) or text embedding of shape (B F L D) if using text embedding cache encoder_attention_mask: the mask for text tokens/embeddings of a fixed shape (B F L) @@ -162,13 +101,10 @@ def construct( - inputs should matches dataloder output order - assume model input/output shape: (b c f+num_img h w) """ - # 1. get image/video latents z using vae + x = x.to(self.dtype) with no_grad(): - if not self.video_emb_cached: - x = ops.stop_gradient(self.get_latents(x)) - - # 2. 
get conditions + # get conditions if not self.text_emb_cached: text_embed = ops.stop_gradient(self.get_condition_embeddings(text_tokens, encoder_attention_mask)) else: @@ -282,7 +218,7 @@ def construct( Video diffusion model forward and loss computation for training Args: - x: pixel values of video frames, resized and normalized to shape (b c f+num_img h w) + x: the latent features of video frames (b c t' h' w'), where t' h' w' are the shape of latent features after vae's encoding. attention_mask: the mask for latent features of shape (b t' h' w'), where t' h' w' are the shape of latent features after vae's encoding. text_tokens: text tokens padded to fixed shape (B F L) or text embedding of shape (B F L D) if using text embedding cache encoder_attention_mask: the mask for text tokens/embeddings of a fixed shape (B F L) @@ -297,9 +233,6 @@ def construct( # 1. get image/video latents z using vae x = x.to(self.dtype) with no_grad(): - if not self.video_emb_cached: - x = ops.stop_gradient(self.get_latents(x)) - # 2. 
_grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    """Return ``grad`` multiplied by ``1 / scale``, with the reciprocal cast to ``grad``'s dtype."""
    return grad * F.cast(reciprocal(scale), F.dtype(grad))


@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_row_tensor(scale, grad):
    """Return a new ``RowTensor`` whose values are ``grad.values`` multiplied by ``1 / scale``.

    The indices and dense shape of ``grad`` are passed through unchanged.
    """
    return RowTensor(
        grad.indices,
        grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
        grad.dense_shape,
    )


class TrainOneStepWrapper(nn.TrainOneStepWithLossScaleCell):
    """TrainStep with ema and clip grad.

    Args:
        network: the cell called as ``network(*inputs)`` to produce the loss.
        optimizer: the optimizer; its ``parameters`` are cloned to build the gradient accumulation
            buffers and its ``get_lr()`` is called on accumulation steps that do not update weights.
        scale_sense (Union[Tensor, Cell]): If this value is a Cell, it will be called
            to update loss scale. If this value is a Tensor, the loss scale can be modified by `set_sense_scale`,
            the shape should be :math:`()` or :math:`(1,)`.
        ema: optional object exposing ``ema_update()``; when given, it is called on every step that is
            not dropped because of overflow.
        updates: accepted for interface compatibility; not referenced by this class.
        drop_overflow_update: if True, network will not be updated when gradient is overflow.
        gradient_accumulation_steps (int): number of steps to accumulate gradients over, must be >= 1.
        clip_grad (bool): whether to clip gradients by global norm before the optimizer step.
        clip_norm (float): the global norm used for clipping, must be a float > 0.0.
        verbose: stored on the instance; not otherwise used by this class.
        zero_helper (class): Zero redundancy optimizer(ZeRO) build helper, default is None.

    Returns:
        Tuple of 3 Tensor, the loss, overflow flag and current loss scale value.
        loss (Tensor) -  A scalar, the loss value.
        overflow (Tensor) -  A scalar, whether overflow occur or not, the type is bool.
        loss scale (Tensor) -  The loss scale value, the shape is :math:`()` or :math:`(1,)`.

    """

    def __init__(
        self,
        network,
        optimizer,
        scale_sense=1.0,
        ema=None,
        updates=0,
        drop_overflow_update=True,
        gradient_accumulation_steps=1,
        clip_grad=False,
        clip_norm=1.0,
        verbose=False,
        zero_helper=None,
    ):
        super().__init__(network, optimizer, scale_sense)
        self.ema = ema
        self.drop_overflow_update = drop_overflow_update

        assert isinstance(clip_grad, bool), f"Invalid type of clip_grad, got {type(clip_grad)}, expected bool"
        # Check the type first so that a non-numeric value fails with this message instead of a
        # TypeError raised by the comparison; the message now states the bound that is actually enforced.
        assert isinstance(clip_norm, float) and clip_norm > 0.0, f"clip_norm must be float > 0.0, but got {clip_norm}"
        self.clip_grad = clip_grad
        self.clip_norm = clip_norm

        assert gradient_accumulation_steps >= 1
        self.accum_steps = gradient_accumulation_steps
        if gradient_accumulation_steps > 1:
            self.accumulated_grads = optimizer.parameters.clone(prefix="grad_accumulated_", init="zeros")

            self.cur_accum_step = ms.Parameter(ms.Tensor(0, dtype=ms.int32), name="accum_step")
            self.zero = Tensor(0, ms.int32)

        self.verbose = verbose
        self.is_cpu_device = context.get_context("device_target") == "CPU"  # to support CPU in CI
        self.skip_start_overflow_check = version.parse(ms.__version__) >= version.parse("2.1")

        self.map = ops.Map()
        self.partial = ops.Partial()

        # zero init
        self.zero_helper = zero_helper
        self.zero_stage = zero_helper.zero_stage if zero_helper is not None else 0
        self.run_optimizer = zero_helper.run_optimizer if zero_helper is not None else self.optimizer
        # with ZeRO enabled, gradient communication is delegated to zero_helper.cal_gradients in construct
        self.grad_reducer = self.grad_reducer if self.zero_stage == 0 else nn.Identity()
        if self.zero_stage != 0:
            self.zero_helper.split_params()
            # re-create the accumulation buffers from the optimizer parameters after the split
            if gradient_accumulation_steps > 1:
                self.accumulated_grads = optimizer.parameters.clone(prefix="grad_accumulated_", init="zeros")

    @ms.jit(jit_config="O1")
    def construct(self, *inputs):
        """Run one training step and return ``(loss, cond, scaling_sens)``.

        ``cond`` is the overflow status returned by ``get_overflow_status`` (a constant ``False``
        tensor on CPU) and ``scaling_sens`` is the loss scale used for this step.
        """
        # compute loss
        weights = self.weights
        loss = self.network(*inputs)  # mini-batch loss
        scaling_sens = self.scale_sense

        # check loss overflow. (after ms2.1, it's done together with gradient overflow checking)
        if self.skip_start_overflow_check:
            status = Tensor([0] * 8, mstype.int32)
        else:
            if not self.is_cpu_device:
                status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
            else:
                status = None

        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))  # loss scale value

        # 1. compute gradients (of the up-scaled loss w.r.t. the model weights)
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)

        # Gradient communication
        if self.zero_helper is not None:
            grads = self.zero_helper.cal_gradients(grads)

        if self.accum_steps == 1:
            grads = self.grad_reducer(grads)
            scaling_sens = ops.depend(scaling_sens, grads)

        # 2. down-scale gradients by loss_scale: grads = grads / scaling_sens.
        # The division by accum_steps is not done here; it is handled by the accumulation op in step 4.
        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)

        # 3. check gradient overflow
        if not self.is_cpu_device:
            cond = self.get_overflow_status(status, grads)
            overflow = self.process_loss_scale(cond)
        else:
            overflow = ms.Tensor(False)
            cond = ms.Tensor(False)

        # accumulate gradients and update model weights if no overflow or allow to update even when overflow
        if (not self.drop_overflow_update) or (not overflow):
            # 4. gradient accumulation if enabled
            if self.accum_steps > 1:
                # self.accumulated_grads += grads / accum_steps
                loss = F.depend(
                    loss, self.hyper_map(F.partial(_grad_accum_op, self.accum_steps), self.accumulated_grads, grads)
                )

                # self.cur_accum_step += 1
                loss = F.depend(loss, ops.assign_add(self.cur_accum_step, Tensor(1, ms.int32)))

                if self.cur_accum_step >= self.accum_steps:
                    # 5. gradient reduction on distributed GPUs/NPUs
                    grads = self.grad_reducer(self.accumulated_grads)

                    # 6. clip grad
                    if self.clip_grad:
                        grads = ops.clip_by_global_norm(grads, self.clip_norm)
                    # 7. optimize
                    loss = F.depend(loss, self.run_optimizer(grads))

                    # clear gradient accumulation states
                    loss = F.depend(loss, self.hyper_map(F.partial(_grad_clear_op), self.accumulated_grads))
                    # self.cur_accum_step = 0
                    loss = F.depend(loss, ops.assign(self.cur_accum_step, self.zero))
                else:
                    # update LR in each gradient step but not optimize net parameter
                    # to ensure the LR curve is consistent
                    # FIXME: for ms>=2.2, get_lr() will not increase global step by 1. we need to do it manually.
                    loss = F.depend(loss, self.optimizer.get_lr())
            else:
                # 5. gradient reduction on distributed GPUs/NPUs
                # 6. clip grad
                if self.clip_grad:
                    grads = ops.clip_by_global_norm(grads, self.clip_norm)
                # 7. optimize
                loss = F.depend(loss, self.run_optimizer(grads))

            # 8.ema
            if self.ema is not None:
                self.ema.ema_update()
        # else:
        #    print("WARNING: Gradient overflow! update skipped.")  # TODO: recover it after 910B in-graph print issue fixed

        return loss, cond, scaling_sens
def _iter_cells(modules):
    """Return the ``nn.Cell`` objects in ``modules``, which may be one cell or an iterable of objects."""
    # A bare cell is wrapped instead of iterated, so callers may pass a single network directly.
    if isinstance(modules, nn.Cell):
        return [modules]
    return [module for module in modules if isinstance(module, nn.Cell)]


def set_train(modules):
    """Call ``set_train(True)`` on a single cell or on every ``nn.Cell`` found in ``modules``."""
    for module in _iter_cells(modules):
        module.set_train(True)


def set_eval(modules):
    """Call ``set_train(False)`` on a single cell or on every ``nn.Cell`` found in ``modules``."""
    for module in _iter_cells(modules):
        module.set_train(False)


def get_latents(vae, x, use_image_num=0):
    """Encode a 5-D pixel tensor with ``vae.encode``, handling trailing image frames separately.

    Args:
        vae: object exposing ``encode``; it is called on 5-D tensors laid out as (b, c, f, h, w).
        x: input of shape (b, 3, f, h, w). When ``use_image_num`` > 0, the last ``use_image_num``
            entries along dim 2 are treated as independent images.
        use_image_num: number of trailing frames to encode one by one as single-frame inputs.

    Returns:
        ``vae.encode(x)`` when ``use_image_num`` is 0; otherwise the encoded video frames and the
        encoded images concatenated along dim 2.

    Raises:
        ValueError: if ``x`` is not 5-D, its channel dim is not 3, or ``use_image_num`` is negative
            or larger than the number of frames.
    """
    if x.dim() != 5:
        raise ValueError("Incorrect Dimensions of x")

    B, C, num_frames, H, W = x.shape
    if C != 3:
        raise ValueError("Expect input shape (b 3 f h w), but get {}".format(x.shape))
    # A negative or over-large count would make the slices below silently select the wrong frames.
    if use_image_num < 0 or use_image_num > num_frames:
        raise ValueError(
            "Expect 0 <= use_image_num <= number of frames ({}), but get {}".format(num_frames, use_image_num)
        )

    if use_image_num == 0:
        return vae.encode(x)  # (b, c, f, h, w)

    videos, images = x[:, :, :-use_image_num], x[:, :, -use_image_num:]
    videos = vae.encode(videos)  # (b, c, f, h, w)
    # (b, c, f, h, w) -> (b, f, c, h, w) -> (b*f, c, h, w) -> (b*f, c, 1, h, w)
    images = images.permute(0, 2, 1, 3, 4).reshape(-1, C, H, W).unsqueeze(2)
    images = vae.encode(images)  # (b*f, c, 1, h, w)
    # (b*f, c, 1, h, w) -> (b*f, c, h, w) -> (b, f, c, h, w) -> (b, c, f, h, w)
    _, c, _, h, w = images.shape
    images = images.squeeze(2).reshape(B, use_image_num, c, h, w).permute(0, 2, 1, 3, 4)
    return mint.cat([videos, images], dim=2)  # b c 16+4, h, w
create dataset # TODO: replace it with new dataset @@ -394,7 +424,6 @@ def main(args): latent_diffusion_eval = DiffusionWithLossEval( model, noise_scheduler, - vae=vae, text_encoder=text_encoder_1, text_emb_cached=args.text_embed_cache, video_emb_cached=False, @@ -403,8 +432,7 @@ def main(args): noise_offset=args.noise_offset, snr_gamma=args.snr_gamma, ) - metrics = {"val loss": get_metric_fn("loss")} - eval_indexes = [0, 1, 2] # the indexes of the output of eval network: loss. pred and label + # 4. build training utils: lr, optim, callbacks, trainer if args.scale_lr: learning_rate = args.start_learning_rate * args.train_batch_size * args.gradient_accumulation_steps * device_num @@ -414,7 +442,7 @@ def main(args): else: learning_rate = args.start_learning_rate end_learning_rate = args.end_learning_rate - + assert args.dataset_sink_mode is False, "Not support data sink mode=True!" if args.dataset_sink_mode and args.sink_size != -1: assert args.sink_size > 0, f"Expect that sink size is a positive integer, but got {args.sink_size}" steps_per_sink = args.sink_size @@ -447,8 +475,10 @@ def main(args): if args.checkpointing_steps is None: ckpt_save_interval = args.ckpt_save_interval step_mode = False + use_step_unit = False else: step_mode = not args.dataset_sink_mode + use_step_unit = True if not args.dataset_sink_mode: ckpt_save_interval = args.checkpointing_steps else: @@ -594,79 +624,18 @@ def main(args): encoder_attention_mask = ms.Tensor(shape=[_bs, None, args.model_max_length], dtype=ms.uint8) net_with_grads.set_inputs(video, attention_mask, text_tokens, encoder_attention_mask) logger.info("Dynamic inputs are initialized for training!") - - if not args.global_bf16: - model = Model( - net_with_grads, - eval_network=latent_diffusion_eval, - metrics=metrics, - eval_indexes=eval_indexes, - ) - else: - model = Model( - net_with_grads, - eval_network=latent_diffusion_eval, - metrics=metrics, - eval_indexes=eval_indexes, - amp_level="O0", - ) - # callbacks - callback = 
[TimeMonitor(args.log_interval), EMAEvalSwapCallback(ema)] - ofm_cb = OverflowMonitor() - callback.append(ofm_cb) - if args.max_train_steps is not None and args.max_train_steps > 0: - callback.append(StopAtStepCallback(args.max_train_steps, global_step=cur_iter)) - + assert args.parallel_mode != "optim", "Optimizer parallelism is not supported!" if args.parallel_mode == "optim": - cb_rank_id = None ckpt_save_dir = os.path.join(ckpt_dir, f"rank_{rank_id}") - output_dir = os.path.join(args.output_dir, "log", f"rank_{rank_id}") if args.ckpt_max_keep != 1: logger.warning("For semi-auto parallel training, the `ckpt_max_keep` is force to be 1.") ckpt_max_keep = 1 - integrated_save = False save_training_resume = False # TODO: support training resume else: - cb_rank_id = rank_id ckpt_save_dir = ckpt_dir - output_dir = None ckpt_max_keep = args.ckpt_max_keep - integrated_save = True save_training_resume = True - if rank_id == 0 or args.parallel_mode == "optim": - save_cb = EvalSaveCallback( - network=latent_diffusion_with_loss.network, - rank_id=cb_rank_id, - ckpt_save_dir=ckpt_save_dir, - output_dir=output_dir, - ema=ema, - save_ema_only=False, - ckpt_save_policy="latest_k", - ckpt_max_keep=ckpt_max_keep, - step_mode=step_mode, - use_step_unit=(args.checkpointing_steps is not None), - ckpt_save_interval=ckpt_save_interval, - log_interval=args.log_interval, - start_epoch=start_epoch, - model_name=args.model.replace("/", "-"), - record_lr=False, - integrated_save=integrated_save, - save_training_resume=save_training_resume, - ) - callback.append(save_cb) - if args.validate: - assert metrics is not None, "Val during training must set the metric functions" - rec_cb = PerfRecorderCallback( - save_dir=args.output_dir, - file_name="result_val.log", - resume=args.resume_from_checkpoint, - metric_names=list(metrics.keys()), - ) - callback.append(rec_cb) - if args.profile: - callback.append(ProfilerCallbackEpoch(2, 2, "./profile_data")) - # Train! 
total_batch_size = args.train_batch_size * device_num * args.gradient_accumulation_steps total_batch_size = total_batch_size // args.sp_size * args.train_sp_batch_size @@ -736,17 +705,142 @@ def main(args): yaml.safe_dump(vars(args), stream=f, default_flow_style=False, sort_keys=False) # 6. train - model.fit( - sink_epochs, - dataloader, - valid_dataset=val_dataloader, - valid_frequency=args.val_interval, - callbacks=callback, - dataset_sink_mode=args.dataset_sink_mode, - valid_dataset_sink_mode=False, # TODO: add support? - sink_size=args.sink_size, - initial_epoch=start_epoch, - ) + if not os.path.exists(f"{args.output_dir}/rank_{rank_id}"): + os.makedirs(f"{args.output_dir}/rank_{rank_id}") + loss_log_file = open(f"{args.output_dir}/rank_{rank_id}/result.log", "w") + loss_log_file.write("step\tloss\ttrain_time(s)\n") + loss_log_file.flush() + if rank_id == 0: + ckpt_manager = CheckpointManager(ckpt_save_dir, "latest_k", k=ckpt_max_keep) + ds_iter = dataloader.create_dict_iterator(args.num_train_epochs - start_epoch) + for epoch in range(start_epoch, args.epochs): + start_time_e = time.time() + set_train(latent_diffusion_with_loss.network) + for step, data in enumerate(ds_iter): + start_time_s = time.time() + x = data["pixel_values"] + if vae is not None: + with no_grad(): + x = get_latents(vae, x, use_image_num=args.use_image_num) + + cur_global_step = epoch * dataloader_size + step + 1 # starting from 1 for logging + loss, overflow, scaling_sens = net_with_grads( + x, data["attention_mask"], data["text_embed"], data["encoder_attention_mask"] + ) + if isinstance(scaling_sens, ms.Parameter): + scaling_sens = scaling_sens.value() + + if overflow: + logger.warning( + f"Overflow occurs in step {cur_global_step}" + + (", drop update." 
if args.drop_overflow_update else ", still update.") + ) + + # log + step_time = time.time() - start_time_s + if step % args.log_interval == 0: + loss = float(loss.asnumpy()) + logger.info( + f"E: {epoch+1}, S: {step+1}, Loss ae: {loss:.4f}, ae loss scaler {scaling_sens}," + + f" Step time: {step_time*1000:.2f}ms" + ) + + loss_log_file.write(f"{cur_global_step}\t{loss:.7f}\t{step_time:.2f}\n") + loss_log_file.flush() + + if rank_id == 0 and step_mode: + cur_epoch = epoch + 1 + if (cur_global_step % ckpt_save_interval == 0) or (cur_global_step == total_train_steps): + ckpt_name = ( + f"{args.model}-e{cur_epoch}.ckpt" + if not use_step_unit + else f"{args.model}-s{cur_global_step}.ckpt" + ) + if ema is not None: + ema.swap_before_eval() + set_eval(latent_diffusion_with_loss.network) + ckpt_manager.save(latent_diffusion_with_loss.network, None, ckpt_name=ckpt_name, append_dict=None) + if save_training_resume: + ms.save_checkpoint( + net_with_grads, + os.path.join(ckpt_dir, "train_resume.ckpt"), + append_dict={ + "epoch_num": cur_epoch - 1, + "loss_scale": scaling_sens, + }, + ) + + if ema is not None: + ema.swap_after_eval() + set_train(latent_diffusion_with_loss.network) + + if cur_global_step == total_train_steps: + break + + epoch_cost = time.time() - start_time_e + per_step_time = epoch_cost / dataloader_size + cur_epoch = epoch + 1 + logger.info( + f"Epoch:[{int(cur_epoch):>3d}/{int(args.num_train_epochs):>3d}], " + f"epoch time:{epoch_cost:.2f}s, per step time:{per_step_time*1000:.2f}ms, " + ) + + if rank_id == 0 and not step_mode: + if (cur_epoch % ckpt_save_interval == 0) or (cur_epoch == args.epochs): + ckpt_name = ( + f"{args.model}-e{cur_epoch}.ckpt" if not use_step_unit else f"{args.model}-s{cur_global_step}.ckpt" + ) + if ema is not None: + ema.swap_before_eval() + set_eval(latent_diffusion_with_loss.network) + ckpt_manager.save(latent_diffusion_with_loss.network, None, ckpt_name=ckpt_name, append_dict=None) + if save_training_resume: + 
ms.save_checkpoint( + net_with_grads, + os.path.join(ckpt_dir, "train_resume.ckpt"), + append_dict={ + "epoch_num": cur_epoch - 1, + "loss_scale": scaling_sens, + }, + ) + if ema is not None: + ema.swap_after_eval() + set_train(latent_diffusion_with_loss.network) + + if rank_id == 0 and args.validate and (cur_epoch % args.val_interval == 0) or (cur_epoch == args.epochs): + # run validation + val_ds_iter = val_dataloader.create_dict_iterator(1) + if ema is not None: + ema.swap_before_eval() + set_eval(latent_diffusion_with_loss.network) + loss_val = 0 + with no_grad(): + val_time_e = time.time() + for iter, data in val_ds_iter: + val_time_s = time.time() + x = data["pixel_values"] + if vae is not None: + with no_grad(): + x = get_latents(vae, x, use_image_num=args.use_image_num) + + loss_val_iter = latent_diffusion_eval( + x, data["attention_mask"], data["text_embed"], data["encoder_attention_mask"] + ) + loss_val_iter = loss_val_iter.asnumpy() + step_time = time.time() - val_time_s + logger.info( + f"Validation [{iter+1}/{val_dataloader_size}]: Val loss {loss_val_iter:.4f}" + + f" Step time: {step_time*1000:.2f}ms" + ) + loss_val += loss_val_iter + loss_val = loss_val / val_dataloader_size + epoch_time = time.time() - val_time_e + logger.info(f"Validation finished within {epoch_time:.2f}s\tAverage Validation loss {loss_val:.4f}") + + if cur_global_step == total_train_steps: + break + # TODO: eval while training + loss_log_file.close() def parse_t2v_train_args(parser): diff --git a/examples/opensora_pku/opensora/utils/ms_utils.py b/examples/opensora_pku/opensora/utils/ms_utils.py index 52efa2ebb0..6c8e9659c2 100644 --- a/examples/opensora_pku/opensora/utils/ms_utils.py +++ b/examples/opensora_pku/opensora/utils/ms_utils.py @@ -5,6 +5,7 @@ from opensora.acceleration.parallel_states import initialize_sequence_parallel_state import mindspore as ms +from mindspore import _no_grad from mindspore.communication.management import get_group_size, get_rank, init from 
@ms.jit_class
class no_grad(_no_grad):
    """
    Context manager that defers to ``_no_grad`` only when the MindSpore context mode is PyNative.

    The mode is read once, when the instance is created; in any other mode entering and leaving
    the context does nothing.
    """

    def __init__(self):
        super().__init__()
        self._pynative = ms.get_context("mode") == ms.PYNATIVE_MODE

    def __enter__(self):
        # outside PyNative mode there is nothing to enter
        if not self._pynative:
            return
        super().__enter__()

    def __exit__(self, *args):
        # mirror __enter__: only leave the parent context if it was entered
        if not self._pynative:
            return
        super().__exit__(*args)