From c068a2fbde934e80758d7a004c791a4309a4a8c6 Mon Sep 17 00:00:00 2001
From: Didan Deng <33117903+wtomin@users.noreply.github.com>
Date: Thu, 28 Nov 2024 11:02:44 +0800
Subject: [PATCH] fix typos

---
 .../opensora/train/train_t2v_diffusers.py | 27 ++++++++++---------
 1 file changed, 14 insertions(+), 13 deletions(-)

diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py
index df3deac688..82beeb1271 100644
--- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py
+++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py
@@ -46,16 +46,14 @@
 logger = logging.getLogger(__name__)
 
 
-def set_train(modules):
-    for module in modules:
-        if isinstance(module, nn.Cell):
-            module.set_train(True)
+def set_train(module):
+    if isinstance(module, nn.Cell):
+        module.set_train(True)
 
 
-def set_eval(modules):
-    for module in modules:
-        if isinstance(module, nn.Cell):
-            module.set_train(False)
+def set_eval(module):
+    if isinstance(module, nn.Cell):
+        module.set_train(False)
 
 
 def get_latents(vae, x, use_image_num=0):
@@ -306,7 +304,6 @@ def main(args):
         noise_scheduler,
         text_encoder=text_encoder_1,
         text_emb_cached=args.text_embed_cache,
-        video_emb_cached=False,
         use_image_num=args.use_image_num,
         dtype=model_dtype,
         noise_offset=args.noise_offset,
@@ -426,7 +423,6 @@ def main(args):
         noise_scheduler,
         text_encoder=text_encoder_1,
         text_emb_cached=args.text_embed_cache,
-        video_emb_cached=False,
         use_image_num=args.use_image_num,
         dtype=model_dtype,
         noise_offset=args.noise_offset,
@@ -713,7 +709,7 @@ def main(args):
     if rank_id == 0:
         ckpt_manager = CheckpointManager(ckpt_save_dir, "latest_k", k=ckpt_max_keep)
     ds_iter = dataloader.create_dict_iterator(args.num_train_epochs - start_epoch)
-    for epoch in range(start_epoch, args.epochs):
+    for epoch in range(start_epoch, args.num_train_epochs):
         start_time_e = time.time()
         set_train(latent_diffusion_with_loss.network)
         for step, data in enumerate(ds_iter):
@@ -786,7 +782,7 @@ def main(args):
         )
 
         if rank_id == 0 and not step_mode:
-            if (cur_epoch % ckpt_save_interval == 0) or (cur_epoch == args.epochs):
+            if (cur_epoch % ckpt_save_interval == 0) or (cur_epoch == args.num_train_epochs):
                 ckpt_name = (
                     f"{args.model}-e{cur_epoch}.ckpt" if not use_step_unit else f"{args.model}-s{cur_global_step}.ckpt"
                 )
@@ -807,7 +803,12 @@ def main(args):
                     ema.swap_after_eval()
         set_train(latent_diffusion_with_loss.network)
 
-        if rank_id == 0 and args.validate and (cur_epoch % args.val_interval == 0) or (cur_epoch == args.epochs):
+        if (
+            rank_id == 0
+            and args.validate
+            and (cur_epoch % args.val_interval == 0)
+            or (cur_epoch == args.num_train_epochs)
+        ):
             # run validation
             val_ds_iter = val_dataloader.create_dict_iterator(1)
             if ema is not None: