From 16cdd9bf1ad2f0796e94e223a2885d9ab768207c Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 29 Oct 2024 15:44:04 +0800 Subject: [PATCH 001/133] op-v0.3 init --- examples/opensora_pku/README.md | 34 +- examples/opensora_pku/examples/rec_image.py | 64 +- examples/opensora_pku/examples/rec_video.py | 73 +- .../opensora_pku/examples/rec_video_folder.py | 56 +- .../opensora_pku/opensora/models/__init__.py | 2 +- .../models/causalvideovae/__init__.py | 92 +- .../models/causalvideovae/eval/cal_fvd.py | 10 +- .../models/causalvideovae/eval/cal_lpips.py | 6 +- .../models/causalvideovae/eval/cal_psnr.py | 6 +- .../models/causalvideovae/eval/cal_ssim.py | 6 +- .../models/causalvideovae/eval/eval.py | 262 +++++ .../causalvideovae/eval/fvd/styleganv/fvd.py | 122 +++ .../causalvideovae/eval/fvd/videogpt/fvd.py | 164 +++ .../eval/fvd/videogpt/ms_i3d.py | 382 +++++++ .../causalvideovae/eval/script/cal_fvd.sh | 9 + .../causalvideovae/eval/script/cal_lpips.sh | 2 +- .../causalvideovae/eval/script/cal_psnr.sh | 2 +- .../causalvideovae/eval/script/cal_ssim.sh | 2 +- .../models/causalvideovae/model/__init__.py | 43 +- .../causalvideovae/model/dataset_videobase.py | 1 + .../causalvideovae/model/losses/__init__.py | 1 + .../model/losses/discriminator.py | 142 +-- .../causalvideovae/model/losses/lpips.py | 6 +- .../model/losses/net_with_loss.py | 81 +- .../model/losses/perceptual_loss.py | 261 +---- .../causalvideovae/model/modules/__init__.py | 22 +- .../causalvideovae/model/modules/attention.py | 176 +--- .../causalvideovae/model/modules/conv.py | 141 +-- .../causalvideovae/model/modules/normalize.py | 99 +- .../causalvideovae/model/modules/ops.py | 35 +- .../model/modules/resnet_block.py | 120 +-- .../model/modules/updownsample.py | 142 +-- .../causalvideovae/model/modules/wavelet.py | 259 +++++ .../models/causalvideovae/model/registry.py | 14 + .../model/utils/distrib_utils.py | 43 + .../causalvideovae/model/utils/video_utils.py | 4 +- .../causalvideovae/model/vae/__init__.py | 6 + .../model/vae/modeling_causalvae.py | 946 ++++++++++++++++++ .../model/vae/modeling_wfvae.py | 799 +++++++++++++++ .../causalvideovae/sample/rec_video_vae.py | 161 +++ .../models/diffusion/opensora/modules.py | 14 +- examples/opensora_pku/opensora/npu_config.py | 344 +++++++ .../opensora/sample/caption_refiner.py | 36 + .../opensora_pku/opensora/sample/rec_image.py | 164 +++ .../opensora_pku/opensora/sample/rec_video.py | 221 ++++ .../opensora_pku/opensora/train/commons.py | 2 +- .../opensora/train/train_causalvae.py | 89 +- examples/opensora_pku/requirements.txt | 11 +- .../scripts/causalvae/rec_image.sh | 7 +- .../scripts/causalvae/rec_video.sh | 13 +- .../scripts/causalvae/rec_video_folder.sh | 4 +- .../scripts/causalvae/train_with_gan_loss.sh | 14 +- .../train_with_gan_loss_multi_device.sh | 4 +- .../scripts/causalvae/wfvae_4dim.json | 23 + examples/opensora_pku/tests/test_wavelet.py | 71 ++ examples/opensora_pku/tests/torch_wavelet.py | 327 ++++++ .../convert_pytorch_ckpt_to_safetensors.py | 12 +- .../tools/model_conversion/convert_wfvae.py | 79 ++ .../inflate_vae2d_to_vae3d.py | 3 +- 59 files changed, 5080 insertions(+), 1154 deletions(-) create mode 100644 examples/opensora_pku/opensora/models/causalvideovae/eval/eval.py create mode 100644 examples/opensora_pku/opensora/models/causalvideovae/eval/fvd/styleganv/fvd.py create mode 100644 examples/opensora_pku/opensora/models/causalvideovae/eval/fvd/videogpt/fvd.py create mode 100644 examples/opensora_pku/opensora/models/causalvideovae/eval/fvd/videogpt/ms_i3d.py create mode 100644 examples/opensora_pku/opensora/models/causalvideovae/eval/script/cal_fvd.sh create mode 100644 examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py create mode 100644 examples/opensora_pku/opensora/models/causalvideovae/model/registry.py create mode 100644 examples/opensora_pku/opensora/models/causalvideovae/model/utils/distrib_utils.py create mode 100644 examples/opensora_pku/opensora/models/causalvideovae/model/vae/__init__.py create mode 100644 examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_causalvae.py create mode 100644 examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py create mode 100644 examples/opensora_pku/opensora/models/causalvideovae/sample/rec_video_vae.py create mode 100644 examples/opensora_pku/opensora/npu_config.py create mode 100644 examples/opensora_pku/opensora/sample/caption_refiner.py create mode 100644 examples/opensora_pku/opensora/sample/rec_image.py create mode 100644 examples/opensora_pku/opensora/sample/rec_video.py create mode 100644 examples/opensora_pku/scripts/causalvae/wfvae_4dim.json create mode 100644 examples/opensora_pku/tests/test_wavelet.py create mode 100644 examples/opensora_pku/tests/torch_wavelet.py create mode 100644 examples/opensora_pku/tools/model_conversion/convert_wfvae.py diff --git a/examples/opensora_pku/README.md b/examples/opensora_pku/README.md index 3b2ef6720f..670437fb64 100644 --- a/examples/opensora_pku/README.md +++ b/examples/opensora_pku/README.md @@ -120,19 +120,16 @@ For EulerOS, instructions on ffmpeg and decord installation are as follows. ## Model Weights -### Open-Sora-Plan v1.2.0 Model Weights +### Open-Sora-Plan v1.3.0 Model Weights -Please download the torch checkpoint of mT5-xxl from [google/mt5-xxl](https://huggingface.co/google/mt5-xxl/tree/main), and download the opensora v1.2.0 models' weights from [LanguageBind/Open-Sora-Plan-v1.2.0](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.2.0/tree/main). Place them under `examples/opensora_pku` as shown below: +Please download the torch checkpoint of mT5-xxl from [google/mt5-xxl](https://huggingface.co/google/mt5-xxl/tree/main), and download the opensora v1.2.0 models' weights from [LanguageBind/Open-Sora-Plan-v1.3.0](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main). Place them under `examples/opensora_pku` as shown below: ```bash mindone/examples/opensora_pku ├───LanguageBind -│ └───Open-Sora-Plan-v1.2.0 -│ ├───1x480p/ -│ ├───29x480p/ -│ ├───29x720p/ -│ ├───93x480p/ -│ ├───93x480p_i2v/ -│ ├───93x720p/ +│ └───Open-Sora-Plan-v1.3.0 +│ ├───any93x640x640/ +│ ├───any93x640x640_i2v/ +│ ├───prompt_refiner/ │ └───vae/ └───google/ └───mt5-xxl/ @@ -147,7 +144,7 @@ mindone/examples/opensora_pku Currently, we can load `.safetensors` files directly in MindSpore, but not `.bin` or `.ckpt` files. We recommend you to convert the `vae/checkpoint.ckpt` and `mt5-xxl/pytorch_model.bin` files to `.safetensor` files manually by running the following commands: ```shell -python tools/model_conversion/convert_pytorch_ckpt_to_safetensors.py --src LanguageBind/Open-Sora-Plan-v1.2.0/vae/checkpoint.ckpt --target LanguageBind/Open-Sora-Plan-v1.2.0/vae/diffusion_pytorch_model.safetensors --config LanguageBind/Open-Sora-Plan-v1.2.0/vae/config.json +python tools/model_conversion/convert_wfvae.py --src LanguageBind/Open-Sora-Plan-v1.3.0/vae/merged.ckpt --target LanguageBind/Open-Sora-Plan-v1.3.0/vae/diffusion_pytorch_model.safetensors --config LanguageBind/Open-Sora-Plan-v1.3.0/vae/config.json python tools/model_conversion/convert_pytorch_ckpt_to_safetensors.py --src google/mt5-xxl/pytorch_model.bin --target google/mt5-xxl/model.safetensors --config google/mt5-xxl/config.json ``` @@ -161,23 +158,26 @@ Once the checkpoint files have all been prepared, you can refer to the inference You can run video-to-video reconstruction task using `scripts/causalvae/rec_video.sh`: ```bash python examples/rec_video.py \ - --ae_path LanguageBind/Open-Sora-Plan-v1.2.0/vae \ + --ae "WFVAEModel_D8_4x8x8" \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ --video_path test.mp4 \ --rec_path rec.mp4 \ --device Ascend \ --sample_rate 1 \ - --num_frames 65 \ - --height 480 \ - --width 640 \ + --num_frames 61 \ + --height 512 \ + --width 512 \ + --fps 30 \ --enable_tiling \ - --tile_overlap_factor 0.125 \ - --save_memory + --mode 1 \ ``` Please change the `--video_path` to the existing video file path and `--rec_path` to the reconstructed video file path. You can set `--grid` to save the original video and the reconstructed video in the same output file. You can also run video reconstruction given an input video folder. See `scripts/causalvae/rec_video_folder.sh`. -### Open-Sora-Plan v1.2.0 Command Line Inference +### Open-Sora-Plan v1.3.0 Command Line Inference + +**To be revised.** You can run text-to-video inference on a single Ascend device using the script `scripts/text_condition/single-device/sample_t2v_29x720p.sh`. ```bash diff --git a/examples/opensora_pku/examples/rec_image.py b/examples/opensora_pku/examples/rec_image.py index 5500e0d5c1..e667b506d3 100644 --- a/examples/opensora_pku/examples/rec_image.py +++ b/examples/opensora_pku/examples/rec_image.py @@ -1,38 +1,25 @@ -""" -Run causal vae reconstruction on a given image -Usage example: -python examples/rec_image.py \ - --ae_path LanguageBind/Open-Sora-Plan-v1.2.0/vae \ - --image_path test.jpg \ - --rec_path rec.jpg \ - --device Ascend \ - --short_size 512 \ - --enable_tiling -""" import argparse import logging import os import sys +import cv2 import numpy as np +from albumentations import Compose, Lambda, Resize, ToFloat from PIL import Image import mindspore as ms -from mindspore import nn mindone_lib_path = os.path.abspath("../../") sys.path.insert(0, mindone_lib_path) -from mindone.utils.amp import auto_mixed_precision + from mindone.utils.config import str2bool from mindone.utils.logger import set_logger sys.path.append(".") -import cv2 -from albumentations import Compose, Lambda, Resize, ToFloat -from opensora.models import CausalVAEModelWrapper -from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate -from opensora.utils.ms_utils import init_env +from opensora.models.causalvideovae import ae_wrapper +from opensora.npu_config import npu_config from opensora.utils.utils import get_precision logger = logging.getLogger(__name__) @@ -60,7 +47,7 @@ def preprocess(image, height: int = 128, width: int = 128): image = video_transform(image=image)["image"] # (h w c) # (h w c) -> (c h w) -> (c t h w) - image = np.transpose(image, (2, 1, 0))[:, None, :, :] + image = np.transpose(image, (2, 0, 1))[:, None, :, :] return image @@ -75,16 +62,10 @@ def transform_to_rgb(x, rescale_to_uint8=True): def main(args): image_path = args.image_path short_size = args.short_size - init_env( - mode=args.mode, - device_target=args.device, - precision_mode=args.precision_mode, - jit_level=args.jit_level, - jit_syntax_level=args.jit_syntax_level, - ) + npu_config.set_npu_env(args) set_logger(name="", output_dir=args.output_path, rank=0) - + dtype = get_precision(args.precision) if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint): logger.info(f"Run inference with MindSpore checkpoint {args.ms_checkpoint}") state_dict = ms.load_checkpoint(args.ms_checkpoint) @@ -94,8 +75,12 @@ def main(args): ) else: state_dict = None - kwarg = {"state_dict": state_dict, "use_safetensors": True} - vae = CausalVAEModelWrapper(args.ae_path, **kwarg) + kwarg = { + "state_dict": state_dict, + "use_safetensors": True, + "dtype": dtype, + } + vae = ae_wrapper[args.ae](args.ae_path, **kwarg) if args.enable_tiling: vae.vae.enable_tiling() @@ -104,25 +89,11 @@ def main(args): vae.set_train(False) for param in vae.get_parameters(): param.requires_grad = False - if args.precision in ["fp16", "bf16"]: - amp_level = "O2" - dtype = get_precision(args.precision) - if dtype == ms.float16: - custom_fp32_cells = [nn.GroupNorm] if args.vae_keep_gn_fp32 else [] - else: - custom_fp32_cells = [nn.AvgPool2d, TrilinearInterpolate] - - vae = auto_mixed_precision(vae, amp_level, dtype, custom_fp32_cells=custom_fp32_cells) - logger.info( - f"Set mixed precision to {amp_level} with dtype={args.precision}, custom fp32_cells {custom_fp32_cells}" - ) - elif args.precision == "fp32": - dtype = get_precision(args.precision) - else: - raise ValueError(f"Unsupported precision {args.precision}") + input_x = np.array(Image.open(image_path)) # (h w c) assert input_x.shape[2], f"Expect the input image has three channels, but got shape {input_x.shape}" x_vae = preprocess(input_x, short_size, short_size) # use image as a single-frame video + x_vae = ms.Tensor(x_vae, dtype).unsqueeze(0) # b c t h w latents = vae.encode(x_vae) latents = latents.to(dtype) @@ -147,6 +118,7 @@ def main(args): parser = argparse.ArgumentParser() parser.add_argument("--image_path", type=str, default="") parser.add_argument("--rec_path", type=str, default="") + parser.add_argument("--ae", type=str, default="WFVAEModel_D8_4x8x8", choices=ae_wrapper.keys()) parser.add_argument("--ae_path", type=str, default="results/pretrained") parser.add_argument("--ms_checkpoint", type=str, default=None) parser.add_argument("--short_size", type=int, default=336) @@ -154,7 +126,7 @@ def main(args): parser.add_argument("--tile_sample_min_size", type=int, default=256) parser.add_argument("--enable_tiling", action="store_true") # ms related - parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") + parser.add_argument("--mode", default=1, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") parser.add_argument( "--precision", default="bf16", diff --git a/examples/opensora_pku/examples/rec_video.py b/examples/opensora_pku/examples/rec_video.py index a653da3b2b..04fe503f1d 100644 --- a/examples/opensora_pku/examples/rec_video.py +++ b/examples/opensora_pku/examples/rec_video.py @@ -1,15 +1,3 @@ -""" -Run causal vae reconstruction on a given video. -Usage example: -python examples/rec_video.py \ - --ae_path path/to/vae/ckpt \ - --video_path test.mp4 \ - --rec_path rec.mp4 \ - --sample_rate 1 \ - --num_frames 65 \ - --height 480 \ - --width 640 \ -""" import argparse import logging import os @@ -21,12 +9,9 @@ from PIL import Image import mindspore as ms -from mindspore import nn mindone_lib_path = os.path.abspath("../../") sys.path.insert(0, mindone_lib_path) -from mindone.utils.amp import auto_mixed_precision -from mindone.utils.config import str2bool from mindone.utils.logger import set_logger from mindone.visualize.videos import save_videos @@ -36,9 +21,8 @@ import cv2 from albumentations import Compose, Lambda, Resize, ToFloat from opensora.dataset.transform import center_crop_th_tw -from opensora.models import CausalVAEModelWrapper -from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate -from opensora.utils.ms_utils import init_env +from opensora.models.causalvideovae import ae_wrapper +from opensora.npu_config import npu_config from opensora.utils.utils import get_precision logger = logging.getLogger(__name__) @@ -117,14 +101,8 @@ def transform_to_rgb(x, rescale_to_uint8=True): def main(args): - init_env( - mode=args.mode, - device_target=args.device, - precision_mode=args.precision_mode, - jit_level=args.jit_level, - jit_syntax_level=args.jit_syntax_level, - ) - + npu_config.set_npu_env(args) + dtype = get_precision(args.precision) set_logger(name="", output_dir=args.output_path, rank=0) if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint): logger.info(f"Run inference with MindSpore checkpoint {args.ms_checkpoint}") @@ -135,38 +113,23 @@ def main(args): ) else: state_dict = None - kwarg = {"state_dict": state_dict, "use_safetensors": True} - vae = CausalVAEModelWrapper(args.ae_path, **kwarg) + kwarg = { + "state_dict": state_dict, + "use_safetensors": True, + "dtype": dtype, + } + vae = ae_wrapper[args.ae](args.ae_path, **kwarg) if args.enable_tiling: vae.vae.enable_tiling() vae.vae.tile_overlap_factor = args.tile_overlap_factor - if args.save_memory: - vae.vae.tile_sample_min_size = args.tile_sample_min_size - vae.vae.tile_latent_min_size = 32 - vae.vae.tile_sample_min_size_t = 29 - vae.vae.tile_latent_min_size_t = 8 vae.set_train(False) for param in vae.get_parameters(): param.requires_grad = False - if args.precision in ["fp16", "bf16"]: - amp_level = "O2" - dtype = get_precision(args.precision) - if dtype == ms.float16: - custom_fp32_cells = [nn.GroupNorm] if args.vae_keep_gn_fp32 else [] - else: - custom_fp32_cells = [nn.AvgPool2d, TrilinearInterpolate] - vae = auto_mixed_precision(vae, amp_level, dtype, custom_fp32_cells=custom_fp32_cells) - logger.info( - f"Set mixed precision to {amp_level} with dtype={args.precision}, custom fp32_cells {custom_fp32_cells}" - ) - elif args.precision == "fp32": - dtype = get_precision(args.precision) - else: - raise ValueError(f"Unsupported precision {args.precision}") x_vae = preprocess(read_video(args.video_path, args.num_frames, args.sample_rate), args.height, args.width) + x_vae = ms.Tensor(x_vae, dtype).unsqueeze(0) # b c t h w latents = vae.encode(x_vae) latents = latents.to(dtype) @@ -202,6 +165,7 @@ def main(args): parser = argparse.ArgumentParser() parser.add_argument("--video_path", type=str, default="") parser.add_argument("--rec_path", type=str, default="") + parser.add_argument("--ae", type=str, default="") parser.add_argument("--ae_path", type=str, default="results/pretrained") parser.add_argument("--ms_checkpoint", type=str, default=None) parser.add_argument("--fps", type=int, default=30) @@ -209,12 +173,10 @@ def main(args): parser.add_argument("--width", type=int, default=336) parser.add_argument("--num_frames", type=int, default=65) parser.add_argument("--sample_rate", type=int, default=1) - parser.add_argument("--tile_overlap_factor", type=float, default=0.25) - parser.add_argument("--tile_sample_min_size", type=int, default=256) parser.add_argument("--enable_tiling", action="store_true") - parser.add_argument("--save_memory", action="store_true") + parser.add_argument("--tile_overlap_factor", type=float, default=0.25) # ms related - parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") + parser.add_argument("--mode", default=1, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") parser.add_argument( "--precision", default="bf16", @@ -223,12 +185,7 @@ def main(args): help="mixed precision type, if fp32, all layer precision is float32 (amp_level=O0), \ if bf16 or fp16, amp_level==O2, part of layers will compute in bf16 or fp16 such as matmul, dense, conv.", ) - parser.add_argument( - "--vae_keep_gn_fp32", - default=False, - type=str2bool, - help="whether keep GroupNorm in fp32. Defaults to False in inference, better to set to True when training vae", - ) + parser.add_argument("--device", type=str, default="Ascend", help="Ascend or GPU") parser.add_argument( "--precision_mode", diff --git a/examples/opensora_pku/examples/rec_video_folder.py b/examples/opensora_pku/examples/rec_video_folder.py index afaaa8e874..a866277b40 100644 --- a/examples/opensora_pku/examples/rec_video_folder.py +++ b/examples/opensora_pku/examples/rec_video_folder.py @@ -7,22 +7,19 @@ from tqdm import tqdm import mindspore as ms -from mindspore import nn mindone_lib_path = os.path.abspath("../../") sys.path.insert(0, mindone_lib_path) -from mindone.utils.amp import auto_mixed_precision from mindone.utils.config import str2bool from mindone.utils.logger import set_logger from mindone.visualize.videos import save_videos sys.path.append(".") from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info -from opensora.models import CausalVAEModelWrapper +from opensora.models.causalvideovae import ae_wrapper from opensora.models.causalvideovae.model.dataset_videobase import VideoDataset, create_dataloader -from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate -from opensora.utils.ms_utils import init_env +from opensora.npu_config import npu_config from opensora.utils.utils import get_precision logger = logging.getLogger(__name__) @@ -46,18 +43,8 @@ def main(args): batch_size = args.batch_size num_workers = args.num_workers assert args.dataset_name == "video", "Only support video reconstruction!" - rank_id, device_num = init_env( - args.mode, - seed=args.seed, - distributed=args.use_parallel, - device_target=args.device, - max_device_memory=args.max_device_memory, - parallel_mode=args.parallel_mode, - precision_mode=args.precision_mode, - sp_size=args.sp_size, - jit_level=args.jit_level, - jit_syntax_level=args.jit_syntax_level, - ) + rank_id, device_num = npu_config.set_npu_env(args) + dtype = get_precision(args.precision) if not os.path.exists(args.generated_video_dir): os.makedirs(args.generated_video_dir, exist_ok=True) @@ -72,42 +59,26 @@ def main(args): ) else: state_dict = None - kwarg = {"state_dict": state_dict, "use_safetensors": True} - vae = CausalVAEModelWrapper(args.ae_path, **kwarg) + kwarg = { + "state_dict": state_dict, + "use_safetensors": True, + "dtype": dtype, + } + vae = ae_wrapper[args.ae](args.ae_path, **kwarg) if args.enable_tiling: vae.vae.enable_tiling() vae.vae.tile_overlap_factor = args.tile_overlap_factor - if args.save_memory: - vae.vae.tile_sample_min_size = args.tile_sample_min_size - vae.vae.tile_latent_min_size = 32 - vae.vae.tile_sample_min_size_t = 29 - vae.vae.tile_latent_min_size_t = 8 vae.set_train(False) for param in vae.get_parameters(): param.requires_grad = False - if args.precision in ["fp16", "bf16"]: - amp_level = "O2" - dtype = get_precision(args.precision) - if dtype == ms.float16: - custom_fp32_cells = [nn.GroupNorm] if args.vae_keep_gn_fp32 else [] - else: - custom_fp32_cells = [nn.AvgPool2d, TrilinearInterpolate] - vae = auto_mixed_precision(vae, amp_level, dtype, custom_fp32_cells=custom_fp32_cells) - logger.info( - f"Set mixed precision to {amp_level} with dtype={args.precision}, custom fp32_cells {custom_fp32_cells}" - ) - elif args.precision == "fp32": - dtype = get_precision(args.precision) - else: - raise ValueError(f"Unsupported precision {args.precision}") ds_config = dict( data_file_path=args.data_file_path, video_column=args.video_column, data_folder=real_video_dir, - size=(height, width), + size=max(height, width), # SmallestMaxSize crop_size=(height, width), disable_flip=True, random_crop=False, @@ -193,6 +164,7 @@ def main(args): parser = argparse.ArgumentParser() parser.add_argument("--real_video_dir", type=str, default="") parser.add_argument("--generated_video_dir", type=str, default="") + parser.add_argument("--ae", type=str, default="") parser.add_argument("--ae_path", type=str, default="results/pretrained") parser.add_argument("--ms_checkpoint", type=str, default=None) parser.add_argument("--sample_fps", type=int, default=30) @@ -209,11 +181,9 @@ def main(args): parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--num_workers", type=int, default=8) parser.add_argument("--tile_overlap_factor", type=float, default=0.25) - parser.add_argument("--tile_sample_min_size", type=int, default=256) parser.add_argument("--enable_tiling", action="store_true") - parser.add_argument("--save_memory", action="store_true") parser.add_argument("--output_origin", action="store_true") - parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") + parser.add_argument("--mode", default=1, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") parser.add_argument( "--precision", default="bf16", diff --git a/examples/opensora_pku/opensora/models/__init__.py b/examples/opensora_pku/opensora/models/__init__.py index 85a4156424..39c8d30de2 100644 --- a/examples/opensora_pku/opensora/models/__init__.py +++ b/examples/opensora_pku/opensora/models/__init__.py @@ -1 +1 @@ -from .causalvideovae.model import CausalVAEModelWrapper +from .causalvideovae import CausalVAEModelWrapper, WFVAEModelWrapper diff --git a/examples/opensora_pku/opensora/models/causalvideovae/__init__.py b/examples/opensora_pku/opensora/models/causalvideovae/__init__.py index e136aff9ac..4c968af7e9 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/__init__.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/__init__.py @@ -1,26 +1,108 @@ -from .model.causal_vae import CausalVAEModelWrapper +import logging + +import mindspore as ms +from mindspore import nn + +from .model.vae import CausalVAEModel, WFVAEModel + +logger = logging.getLogger(__name__) + + +class CausalVAEModelWrapper(nn.Cell): + def __init__(self, model_path, subfolder=None, cache_dir=None, use_ema=False, **kwargs): + super(CausalVAEModelWrapper, self).__init__() + # if os.path.exists(ckpt): + # self.vae = CausalVAEModel.load_from_checkpoint(ckpt) + self.vae, loading_info = CausalVAEModel.from_pretrained( + model_path, subfolder=subfolder, cache_dir=cache_dir, output_loading_info=True, **kwargs + ) + logger.info(loading_info) + if use_ema: + self.vae.init_from_ema(model_path) + self.vae = self.vae.ema + + def encode(self, x): # b c t h w + # x = self.vae.encode(x) + x = self.vae.encode(x) * 0.18215 + return x + + def decode(self, x): + # x = self.vae.decode(x) + x = self.vae.decode(x / 0.18215) + # b c t h w -> b t c h w + x = x.permute(0, 2, 1, 3, 4) + return x + + def dtype(self): + return self.vae.dtype + + +class WFVAEModelWrapper(nn.Cell): + def __init__(self, model_path, dtype=ms.float32, subfolder=None, cache_dir=None, **kwargs): + super(WFVAEModelWrapper, self).__init__() + self.vae = WFVAEModel.from_pretrained( + model_path, subfolder=subfolder, cache_dir=cache_dir, dtype=dtype, **kwargs + ) + self.shift = ms.Tensor(self.vae.config.shift)[None, :, None, None, None] + self.scale = ms.Tensor(self.vae.config.scale)[None, :, None, None, None] + + def encode(self, x): + x = (self.vae.encode(x) - self.shift.to(dtype=x.dtype)) * self.scale.to(dtype=x.dtype) + return x + + def decode(self, x): + x = x / self.scale.to(dtype=x.dtype) + self.shift.to(dtype=x.dtype) + x = self.vae.decode(x) + # b c t h w -> b t c h w + x = x.transpose(0, 2, 1, 3, 4) + return x + + def dtype(self): + return self.vae.dtype + + +ae_wrapper = { + "CausalVAEModel_D4_2x8x8": CausalVAEModelWrapper, + "CausalVAEModel_D8_2x8x8": CausalVAEModelWrapper, + "CausalVAEModel_D4_4x8x8": CausalVAEModelWrapper, + "CausalVAEModel_D8_4x8x8": CausalVAEModelWrapper, + "WFVAEModel_D8_4x8x8": WFVAEModelWrapper, + "WFVAEModel_D16_4x8x8": WFVAEModelWrapper, + "WFVAEModel_D32_4x8x8": WFVAEModelWrapper, + "WFVAEModel_D32_8x8x8": WFVAEModelWrapper, +} ae_stride_config = { "CausalVAEModel_D4_2x8x8": [2, 8, 8], "CausalVAEModel_D8_2x8x8": [2, 8, 8], "CausalVAEModel_D4_4x8x8": [4, 8, 8], "CausalVAEModel_D8_4x8x8": [4, 8, 8], + "WFVAEModel_D8_4x8x8": [4, 8, 8], + "WFVAEModel_D16_4x8x8": [4, 8, 8], + "WFVAEModel_D32_4x8x8": [4, 8, 8], + "WFVAEModel_D32_8x8x8": [8, 8, 8], } - ae_channel_config = { "CausalVAEModel_D4_2x8x8": 4, "CausalVAEModel_D8_2x8x8": 8, "CausalVAEModel_D4_4x8x8": 4, "CausalVAEModel_D8_4x8x8": 8, + "WFVAEModel_D8_4x8x8": 8, + "WFVAEModel_D16_4x8x8": 16, + "WFVAEModel_D32_4x8x8": 32, + "WFVAEModel_D32_8x8x8": 32, } - ae_denorm = { "CausalVAEModel_D4_2x8x8": lambda x: (x + 1.0) / 2.0, "CausalVAEModel_D8_2x8x8": lambda x: (x + 1.0) / 2.0, "CausalVAEModel_D4_4x8x8": lambda x: (x + 1.0) / 2.0, "CausalVAEModel_D8_4x8x8": lambda x: (x + 1.0) / 2.0, + "WFVAEModel_D8_4x8x8": lambda x: (x + 1.0) / 2.0, + "WFVAEModel_D16_4x8x8": lambda x: (x + 1.0) / 2.0, + "WFVAEModel_D32_4x8x8": lambda x: (x + 1.0) / 2.0, + "WFVAEModel_D32_8x8x8": lambda x: (x + 1.0) / 2.0, } ae_norm = { @@ -28,4 +110,8 @@ "CausalVAEModel_D8_2x8x8": lambda x: 2.0 * x - 1.0, "CausalVAEModel_D4_4x8x8": lambda x: 2.0 * x - 1.0, "CausalVAEModel_D8_4x8x8": lambda x: 2.0 * x - 1.0, + "WFVAEModel_D8_4x8x8": lambda x: 2.0 * x - 1.0, + "WFVAEModel_D16_4x8x8": lambda x: 2.0 * x - 1.0, + "WFVAEModel_D32_4x8x8": lambda x: 2.0 * x - 1.0, + "WFVAEModel_D32_8x8x8": lambda x: 2.0 * x - 1.0, } diff --git a/examples/opensora_pku/opensora/models/causalvideovae/eval/cal_fvd.py b/examples/opensora_pku/opensora/models/causalvideovae/eval/cal_fvd.py index a702327b2a..7628e51c5a 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/eval/cal_fvd.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/eval/cal_fvd.py @@ -1,15 +1,15 @@ from tqdm import tqdm -from mindspore import ops +from mindspore import mint def trans(x): # if greyscale images add channel if x.shape[-3] == 1: - x = x.repeat(1, 1, 3, 1, 1) + x = x.repeat(3, axis=2) # permute BTCHW -> BCTHW - x = x.permute(0, 2, 1, 3, 4) + x = x.transpose(0, 2, 1, 3, 4) return x @@ -77,8 +77,8 @@ def main(): VIDEO_LENGTH = 50 CHANNEL = 3 SIZE = 64 - videos1 = ops.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE) - videos2 = ops.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE) + videos1 = mint.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE) + videos2 = mint.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE) import json diff --git a/examples/opensora_pku/opensora/models/causalvideovae/eval/cal_lpips.py b/examples/opensora_pku/opensora/models/causalvideovae/eval/cal_lpips.py index 31f7e88368..99942a23ab 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/eval/cal_lpips.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/eval/cal_lpips.py @@ -4,7 +4,7 @@ from opensora.models.causalvideovae.model.losses.lpips import LPIPS import mindspore as ms -from mindspore import ops +from mindspore import mint spatial = True # Return a spatial map of perceptual distance. lpips_ckpt_path = os.path.join("pretrained", "lpips_vgg-426bf45c.ckpt") @@ -92,14 +92,14 @@ def main(): VIDEO_LENGTH = 50 CHANNEL = 3 SIZE = 64 - videos1 = ops.zeros( + videos1 = mint.zeros( NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, ) - videos2 = ops.ones( + videos2 = mint.ones( NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, diff --git a/examples/opensora_pku/opensora/models/causalvideovae/eval/cal_psnr.py b/examples/opensora_pku/opensora/models/causalvideovae/eval/cal_psnr.py index a12ffa351a..088a38ed0d 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/eval/cal_psnr.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/eval/cal_psnr.py @@ -2,7 +2,7 @@ import numpy as np -from mindspore import ops +from mindspore import mint def img_psnr(img1, img2): @@ -77,8 +77,8 @@ def main(): VIDEO_LENGTH = 50 CHANNEL = 3 SIZE = 64 - videos1 = ops.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE) - videos2 = ops.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE) + videos1 = mint.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE) + videos2 = mint.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE) import json diff --git a/examples/opensora_pku/opensora/models/causalvideovae/eval/cal_ssim.py b/examples/opensora_pku/opensora/models/causalvideovae/eval/cal_ssim.py index 4ad92deff8..5b174ffdba 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/eval/cal_ssim.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/eval/cal_ssim.py @@ -1,7 +1,7 @@ import cv2 import numpy as np -from mindspore import ops +from mindspore import mint def ssim(img1, img2): @@ -103,8 +103,8 @@ def main(): VIDEO_LENGTH = 50 CHANNEL = 3 SIZE = 64 - videos1 = ops.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE) - videos2 = ops.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE) + videos1 = mint.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE) + videos2 = mint.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE) import json diff --git a/examples/opensora_pku/opensora/models/causalvideovae/eval/eval.py b/examples/opensora_pku/opensora/models/causalvideovae/eval/eval.py new file mode 100644 index 0000000000..0993a4a543 --- /dev/null +++ b/examples/opensora_pku/opensora/models/causalvideovae/eval/eval.py @@ -0,0 +1,262 @@ +"""Calculates the CLIP Scores + +The CLIP model is a contrasitively learned language-image model. There is +an image encoder and a text encoder. It is believed that the CLIP model could +measure the similarity of cross modalities. Please find more information from +https://github.com/openai/CLIP. + +The CLIP Score measures the Cosine Similarity between two embedded features. +This repository utilizes the pretrained CLIP Model to calculate +the mean average of cosine similarities. + +See --help to see further details. + +Code apapted from https://github.com/mseitzer/pytorch-fid and https://github.com/openai/CLIP. + +Copyright 2023 The Hong Kong Polytechnic University + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import math +import os +import os.path as osp +import sys +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser + +import numpy as np +from decord import VideoReader + +mindone_lib_path = os.path.abspath("../../") +sys.path.insert(0, mindone_lib_path) +sys.path.append(".") +# from opensora.eval.cal_fvd import calculate_fvd +from opensora.eval.cal_lpips import calculate_lpips +from opensora.eval.cal_psnr import calculate_psnr + +try: + from opensora.eval.cal_flolpips import calculate_flolpips + + flolpips_isavailable = True +except Exception: + flolpips_isavailable = False +from opensora.eval.cal_ssim import calculate_ssim +from opensora.models.causalvideovae.model.dataset_videobase import create_dataloader +from opensora.utils.dataset_utils import create_video_transforms +from tqdm import tqdm + + +class VideoDataset: + def __init__( + self, + real_video_dir, + generated_video_dir, + num_frames, + sample_rate=1, + crop_size=None, + resolution=128, + output_columns=["real", "generated"], + ) -> None: + super().__init__() + self.real_video_files = self.combine_without_prefix(real_video_dir) + self.generated_video_files = self.combine_without_prefix(generated_video_dir) + assert ( + len(self.real_video_files) == len(self.generated_video_files) and len(self.real_video_files) > 0 + ), "Expect that the real and generated folders are not empty and contain the equal number of videos!" + self.num_frames = num_frames + self.sample_rate = sample_rate + self.crop_size = crop_size + self.short_size = resolution + self.output_columns = output_columns + + self.pixel_transforms = create_video_transforms( + size=self.short_size, + crop_size=crop_size, + random_crop=False, + disable_flip=True, + num_frames=num_frames, + backend="al", + ) + + def __len__(self): + return len(self.real_video_files) + + def __getitem__(self, index): + if index >= len(self): + raise IndexError + real_video_file = self.real_video_files[index] + generated_video_file = self.generated_video_files[index] + if os.path.basename(real_video_file).split(".")[0] != os.path.basename(generated_video_file).split(".")[0]: + print( + f"Warning! video file name mismatch! real and generated {os.path.basename(real_video_file)} and {os.path.basename(generated_video_file)}" + ) + real_video_tensor = self._load_video(real_video_file) + generated_video_tensor = self._load_video(generated_video_file) + return real_video_tensor.astype(np.float32), generated_video_tensor.astype(np.float32) + + def _load_video(self, video_path): + num_frames = self.num_frames + sample_rate = self.sample_rate + decord_vr = VideoReader( + video_path, + ) + total_frames = len(decord_vr) + sample_frames_len = sample_rate * num_frames + + if total_frames >= sample_frames_len: + s = 0 + e = s + sample_frames_len + num_frames = num_frames + else: + s = 0 + e = total_frames + num_frames = int(total_frames / sample_frames_len * num_frames) + print(f"Video total number of frames {total_frames} is less than the target num_frames {sample_frames_len}") + print(video_path) + + frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int) + pixel_values = decord_vr.get_batch(frame_id_list).asnumpy() + # video_data = video_data.transpose(0, 3, 1, 2) # (T, H, W, C) -> (C, T, H, W) + # NOTE:it's to ensure augment all frames in a video in the same way. + # ref: https://albumentations.ai/docs/examples/example_multi_target/ + + inputs = {"image": pixel_values[0]} + for i in range(num_frames - 1): + inputs[f"image{i}"] = pixel_values[i + 1] + + output = self.pixel_transforms(**inputs) + + pixel_values = np.stack(list(output.values()), axis=0) + # (t h w c) -> (t c h w) + pixel_values = np.transpose(pixel_values, (0, 3, 1, 2)) + pixel_values = pixel_values / 255.0 + return pixel_values + + def combine_without_prefix(self, folder_path, prefix="."): + folder = [] + assert os.path.exists(folder_path), f"Expect that {folder_path} exist!" + for name in os.listdir(folder_path): + if name[0] == prefix: + continue + if osp.isfile(osp.join(folder_path, name)): + folder.append(osp.join(folder_path, name)) + folder.sort() + return folder + + +def calculate_common_metric(args, dataloader, dataset_size): + score_list = [] + index = 0 + for batch_data in tqdm( + dataloader, total=dataset_size + ): # {'real': real_video_tensor, 'generated':generated_video_tensor } + real_videos = batch_data["real"] + generated_videos = batch_data["generated"] + assert real_videos.shape[2] == generated_videos.shape[2] + if args.metric == "fvd": + if index == 0: + print("calculate fvd...") + raise ValueError + # tmp_list = list(calculate_fvd(real_videos, generated_videos, method=args.fvd_method)["value"].values()) + elif args.metric == "ssim": + if index == 0: + print("calculate ssim...") + tmp_list = list(calculate_ssim(real_videos, generated_videos)["value"].values()) + elif args.metric == "psnr": + if index == 0: + print("calculate psnr...") + tmp_list = list(calculate_psnr(real_videos, generated_videos)["value"].values()) + elif args.metric == "flolpips": + if flolpips_isavailable: + result = calculate_flolpips( + real_videos, + generated_videos, + ) + tmp_list = list(result["value"].values()) + else: + continue + else: + if index == 0: + print("calculate_lpips...") + tmp_list = list( + calculate_lpips( + real_videos, + generated_videos, + )["value"].values() + ) + index += 1 + score_list += tmp_list + return np.mean(score_list) + + +def main(): + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + parser.add_argument("--batch_size", type=int, default=2, help="Batch size to use") + parser.add_argument("--real_video_dir", type=str, help=("the path of real videos`")) + parser.add_argument("--generated_video_dir", type=str, help=("the path of generated videos`")) + parser.add_argument("--device", type=str, default=None, help="Device to use. Like GPU or Ascend") + parser.add_argument( + "--num_workers", + type=int, + default=8, + help=("Number of processes to use for data loading. " "Defaults to `min(8, num_cpus)`"), + ) + parser.add_argument("--sample_fps", type=int, default=30) + parser.add_argument("--resolution", type=int, default=336) + parser.add_argument("--crop_size", type=int, default=None) + parser.add_argument("--num_frames", type=int, default=100) + parser.add_argument("--sample_rate", type=int, default=1) + parser.add_argument("--subset_size", type=int, default=None) + parser.add_argument("--metric", type=str, default="fvd", choices=["fvd", "psnr", "ssim", "lpips", "flolpips"]) + parser.add_argument("--fvd_method", type=str, default="styleganv", choices=["styleganv", "videogpt"]) + + args = parser.parse_args() + + if args.num_workers is None: + try: + num_cpus = len(os.sched_getaffinity(0)) + except AttributeError: + # os.sched_getaffinity is not available under Windows, use + # os.cpu_count instead (which may not return the *available* number + # of CPUs). + num_cpus = os.cpu_count() + + num_workers = min(num_cpus, 8) if num_cpus is not None else 0 + else: + num_workers = args.num_workers + + dataset = VideoDataset( + args.real_video_dir, + args.generated_video_dir, + num_frames=args.num_frames, + sample_rate=args.sample_rate, + crop_size=args.crop_size, + resolution=args.resolution, + ) + + dataloader = create_dataloader( + dataset, + batch_size=args.batch_size, + ds_name="video", + num_parallel_workers=num_workers, + shuffle=False, + drop_remainder=False, + ) + dataset_size = math.ceil(len(dataset) / float(args.batch_size)) + dataloader = dataloader.create_dict_iterator(1, output_numpy=True) + metric_score = calculate_common_metric(args, dataloader, dataset_size) + print("metric: ", args.metric, " ", metric_score) + + +if __name__ == "__main__": + main() diff --git a/examples/opensora_pku/opensora/models/causalvideovae/eval/fvd/styleganv/fvd.py b/examples/opensora_pku/opensora/models/causalvideovae/eval/fvd/styleganv/fvd.py new file mode 100644 index 0000000000..fa42d76f46 --- /dev/null +++ b/examples/opensora_pku/opensora/models/causalvideovae/eval/fvd/styleganv/fvd.py @@ -0,0 +1,122 @@ +import math +import os + +from mindspore import context, export, load, mint, nn, ops + +try: + import torch +except ImportError: + print( + "For the first-time running, torch is required to load torchscript model and convert to onnx, but import torch leads to an ImportError!" + ) + +# https://github.com/universome/fvd-comparison + + +def load_i3d_pretrained(bs=1): + i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt" + filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "i3d_torchscript.pt") + onnx_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "i3d_torchscript.onnx") + mindir_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "i3d_torchscript.mindir") + if not os.path.exists(mindir_filepath): + if not os.path.exists(filepath): + print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.") + os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}") + if not os.path.exists(onnx_filepath): + # convert torch jit model to onnx model + model = torch.jit.load(filepath).eval() + dummy_input = torch.randn(bs, 3, 224, 224) + # Export the model to ONNX + torch.onnx.export(model, dummy_input, onnx_filepath, export_params=True, opset_version=11) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + export(onnx_filepath, mindir_filepath, file_format="MINDIR") + # + graph = load(mindir_filepath) + model = nn.GraphCell(graph) + model.set_train(False) + for param in model.get_parameters(): + param.requires_grad = False + + return model + + +def get_feats(videos, detector, bs=10): + # videos : torch.tensor BCTHW [0, 1] + detector_kwargs = dict( + rescale=False, resize=False, return_features=True + ) # Return raw features before the softmax layer. + feats = np.empty((0, 400)) + + for i in range((len(videos) - 1) // bs + 1): + feats = np.vstack( + [ + feats, + detector( + mint.stack([preprocess_single(video) for video in videos[i * bs : (i + 1) * bs]]), **detector_kwargs + ).asnumpy(), + ] + ) + return feats + + +def get_fvd_feats(videos, i3d, bs=10): + # videos in [0, 1] as torch tensor BCTHW + # videos = [preprocess_single(video) for video in videos] + embeddings = get_feats(videos, i3d, bs) + return embeddings + + +def preprocess_single(video, resolution=224, sequence_length=None): + # video: CTHW, [0, 1] + c, t, h, w = video.shape + + # temporal crop + if sequence_length is not None: + assert sequence_length <= t + video = video[:, :sequence_length] + + # scale shorter side to resolution + scale = resolution / min(h, w) + if h < w: + target_size = (resolution, math.ceil(w * scale)) + else: + target_size = (math.ceil(h * scale), resolution) + video = ops.interpolate(video, size=target_size, mode="bilinear", align_corners=False) + + # center crop + c, t, h, w = video.shape + w_start = (w - resolution) // 2 + h_start = (h - resolution) // 2 + video = video[:, :, h_start : h_start + resolution, w_start : w_start + resolution] + + # [0, 1] -> [-1, 1] + video = (video - 0.5) * 2 + + return video + + +""" +Copy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py +""" +from typing import Tuple + +import numpy as np +from scipy.linalg import sqrtm + + +def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + mu = feats.mean(axis=0) # [d] + sigma = np.cov(feats, rowvar=False) # [d, d] + return mu, sigma + + +def frechet_distance(feats_fake: np.ndarray, feats_real: np.ndarray) -> float: + mu_gen, sigma_gen = compute_stats(feats_fake) + mu_real, sigma_real = compute_stats(feats_real) + m = np.square(mu_gen - mu_real).sum() + if feats_fake.shape[0] > 1: + s, _ = sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member + fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) + else: + fid = np.real(m) + return float(fid) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/eval/fvd/videogpt/fvd.py b/examples/opensora_pku/opensora/models/causalvideovae/eval/fvd/videogpt/fvd.py new file mode 100644 index 0000000000..67aaa46198 --- /dev/null +++ b/examples/opensora_pku/opensora/models/causalvideovae/eval/fvd/videogpt/fvd.py @@ -0,0 +1,164 @@ +import math +import os + +import numpy as np + +import mindspore as ms +from mindspore import Tensor, mint, ops + +try: + import torch +except ImportError: + print( + "For the first-time running, torch is required to load torchscript model and convert to onnx, but import torch leads to an ImportError!" + ) + + +def load_i3d_pretrained(): + i3D_WEIGHTS_URL = ( + "https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI" + ) + filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "i3d_pretrained_400.pt") + ms_filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "i3d_pretrained_400.ckpt") + if not os.path.exists(ms_filepath): + if not os.path.exists(filepath): + print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.") + os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}") + # convert torch ckpt to mindspore ckpt + state_dict = torch.load_state_dict(torch.load(filepath)) + raise ValueError("Not converted") + from .ms_i3d import InceptionI3d + + model = InceptionI3d(400, in_channels=3) + state_dict = ms.load_checkpoint(ms_filepath) + m, u = ms.load_param_into_net(model, state_dict) + print("net param not load: ", m, len(m)) + print("ckpt param not load: ", u, len(u)) + + model.set_train(False) + for param in model.get_parameters(): + param.requires_grad = False + return model + + +def preprocess_single(video, resolution, sequence_length=None): + # video: THWC, {0, ..., 255} + video = video.transpose(0, 3, 1, 2).float() / 255.0 # TCHW + t, c, h, w = video.shape + + # temporal crop + if sequence_length is not None: + assert sequence_length <= t + video = video[:sequence_length] + + # scale shorter side to resolution + scale = resolution / min(h, w) + if h < w: + target_size = (resolution, math.ceil(w * scale)) + else: + target_size = (math.ceil(h * scale), resolution) + video = ops.interpolate(video, size=target_size, mode="bilinear", align_corners=False) + + # center crop + t, c, h, w = video.shape + w_start = (w - resolution) // 2 + h_start = (h - resolution) // 2 + video = video[:, :, h_start : h_start + resolution, w_start : w_start + resolution] + video = video.transpose(1, 0, 2, 3) # CTHW + + video -= 0.5 + + return video + + +def preprocess(videos, target_resolution=224): + # we should tras videos in [0-1] [b c t h w] as th.float + # -> videos in {0, ..., 255} [b t h w c] as np.uint8 array + # b c t h w -> b t h w c + videos = videos.transpose(0, 2, 3, 4, 1) + videos = (videos * 255).numpy().astype(np.uint8) + + b, t, h, w, c = videos.shape + videos = Tensor(videos) + videos = mint.stack([preprocess_single(video, target_resolution) for video in videos]) + return videos * 2 # [-0.5, 0.5] -> [-1, 1] + + +def get_fvd_logits(videos, i3d, bs=10): + videos = preprocess(videos) + embeddings = get_logits(i3d, videos, bs=bs) + return embeddings + + +# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161 +def _symmetric_matrix_square_root(mat, eps=1e-10): + s, u, v = ops.svd(mat) + si = mint.where(s < eps, s, mint.sqrt(s)) + return mint.matmul(mint.matmul(u, ops.diag(si)), v.t()) + + +# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400 +def trace_sqrt_product(sigma, sigma_v): + sqrt_sigma = _symmetric_matrix_square_root(sigma) + sqrt_a_sigmav_a = mint.matmul(sqrt_sigma, mint.matmul(sigma_v, sqrt_sigma)) + return mint.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) + + +# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2 +def cov(m, rowvar=False): + """Estimate a covariance matrix given data. + + Covariance indicates the level to which two variables vary together. + If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, + then the covariance matrix element `C_{ij}` is the covariance of + `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. + + Args: + m: A 1-D or 2-D array containing multiple variables and observations. + Each row of `m` represents a variable, and each column a single + observation of all those variables. + rowvar: If `rowvar` is True, then each row represents a + variable, with observations in the columns. Otherwise, the + relationship is transposed: each column represents a variable, + while the rows contain observations. + + Returns: + The covariance matrix of the variables. + """ + if m.ndim > 2: + raise ValueError("m has more than 2 dimensions") + if m.ndim < 2: + m = m.view(1, -1) + if not rowvar and m.shape[0] != 1: + m = m.t() + + fact = 1.0 / (m.shape[1] - 1) # unbiased estimate + m -= mint.mean(m, dim=1, keepdim=True) + mt = m.t() # if complex: mt = m.t().conj() + return fact * m.matmul(mt).squeeze() + + +def frechet_distance(x1, x2): + x1 = x1.flatten(start_dim=1) + x2 = x2.flatten(start_dim=1) + m, m_w = x1.mean(dim=0), x2.mean(dim=0) + sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False) + mean = mint.sum((m - m_w) ** 2) + if x1.shape[0] > 1: + sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) + trace = mint.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component + fd = trace + mean + else: + fd = np.real(mean) + return float(fd) + + +def get_logits(i3d, videos, bs=10): + # assert videos.shape[0] % 16 == 0 + logits = [] + for i in range(0, videos.shape[0], bs): + batch = videos[i : i + bs] + # logits.append(i3d.module.extract_features(batch)) # wrong + logits.append(i3d(batch)) # right + logits = mint.cat(logits, dim=0) + return logits diff --git a/examples/opensora_pku/opensora/models/causalvideovae/eval/fvd/videogpt/ms_i3d.py b/examples/opensora_pku/opensora/models/causalvideovae/eval/fvd/videogpt/ms_i3d.py new file mode 100644 index 0000000000..5c284cd443 --- /dev/null +++ b/examples/opensora_pku/opensora/models/causalvideovae/eval/fvd/videogpt/ms_i3d.py @@ -0,0 +1,382 @@ +# Original code from https://github.com/piergiaj/pytorch-i3d +from opensora.npu_config import npu_config + +from mindspore import mint, nn + + +class MaxPool3dSamePadding(nn.MaxPool3d): + def compute_pad(self, dim, s): + if s % self.stride[dim] == 0: + return max(self.kernel_size[dim] - self.stride[dim], 0) + else: + return max(self.kernel_size[dim] - (s % self.stride[dim]), 0) + + def construct(self, x): + # compute 'same' padding + (batch, channel, t, h, w) = x.shape + + pad_t = self.compute_pad(0, t) + pad_h = self.compute_pad(1, h) + pad_w = self.compute_pad(2, w) + + pad_t_f = pad_t // 2 + pad_t_b = pad_t - pad_t_f + pad_h_f = pad_h // 2 + pad_h_b = pad_h - pad_h_f + pad_w_f = pad_w // 2 + pad_w_b = pad_w - pad_w_f + + pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) + x = mint.nn.functional.pad(x, pad) + return super(MaxPool3dSamePadding, self).construct(x) + + +class Unit3D(nn.Cell): + def __init__( + self, + in_channels, + output_channels, + kernel_shape=(1, 1, 1), + stride=(1, 1, 1), + padding=0, + activation_fn=mint.nn.functional.relu, + use_batch_norm=True, + use_bias=False, + name="unit_3d", + ): + """Initializes Unit3D module.""" + super(Unit3D, self).__init__() + + self._output_channels = output_channels + self._kernel_shape = kernel_shape + self._stride = stride + self._use_batch_norm = use_batch_norm + self._activation_fn = activation_fn + self._use_bias = use_bias + self.name = name + self.padding = padding + + self.conv3d = nn.Conv3d( + in_channels=in_channels, + out_channels=self._output_channels, + kernel_size=self._kernel_shape, + stride=self._stride, + padding=0, # we always want padding to be 0 here. We will dynamically pad based on input size in construct function + has_bias=self._use_bias, + ) + + if self._use_batch_norm: + self.bn = nn.BatchNorm3d(self._output_channels, eps=1e-5, momentum=0.001) + + def compute_pad(self, dim, s): + if s % self._stride[dim] == 0: + return max(self._kernel_shape[dim] - self._stride[dim], 0) + else: + return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0) + + def construct(self, x): + # compute 'same' padding + (batch, channel, t, h, w) = x.shape + pad_t = self.compute_pad(0, t) + pad_h = self.compute_pad(1, h) + pad_w = self.compute_pad(2, w) + + pad_t_f = pad_t // 2 + pad_t_b = pad_t - pad_t_f + pad_h_f = pad_h // 2 + pad_h_b = pad_h - pad_h_f + pad_w_f = pad_w // 2 + pad_w_b = pad_w - pad_w_f + + pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) + x = mint.nn.functional.pad(x, pad) + if npu_config is not None and npu_config.on_npu: + x = npu_config.run_conv3d(self.conv3d, x, x.dtype) + else: + x = self.conv3d(x) + + if self._use_batch_norm: + x = self.bn(x) + if self._activation_fn is not None: + x = self._activation_fn(x) + return x + + +class InceptionModule(nn.Cell): + def __init__(self, in_channels, out_channels, name): + super(InceptionModule, self).__init__() + + self.b0 = Unit3D( + in_channels=in_channels, + output_channels=out_channels[0], + kernel_shape=[1, 1, 1], + padding=0, + name=name + "/Branch_0/Conv3d_0a_1x1", + ) + self.b1a = Unit3D( + in_channels=in_channels, + output_channels=out_channels[1], + kernel_shape=[1, 1, 1], + padding=0, + name=name + "/Branch_1/Conv3d_0a_1x1", + ) + self.b1b = Unit3D( + in_channels=out_channels[1], + output_channels=out_channels[2], + kernel_shape=[3, 3, 3], + name=name + "/Branch_1/Conv3d_0b_3x3", + ) + self.b2a = Unit3D( + in_channels=in_channels, + output_channels=out_channels[3], + kernel_shape=[1, 1, 1], + padding=0, + name=name + "/Branch_2/Conv3d_0a_1x1", + ) + self.b2b = Unit3D( + in_channels=out_channels[3], + output_channels=out_channels[4], + kernel_shape=[3, 3, 3], + name=name + "/Branch_2/Conv3d_0b_3x3", + ) + self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(1, 1, 1), padding=0) + self.b3b = Unit3D( + in_channels=in_channels, + output_channels=out_channels[5], + kernel_shape=[1, 1, 1], + padding=0, + name=name + "/Branch_3/Conv3d_0b_1x1", + ) + self.name = name + + def construct(self, x): + b0 = self.b0(x) + b1 = self.b1b(self.b1a(x)) + b2 = self.b2b(self.b2a(x)) + b3 = self.b3b(self.b3a(x)) + return mint.cat([b0, b1, b2, b3], dim=1) + + +class InceptionI3d(nn.Cell): + """Inception-v1 I3D architecture. + The model is introduced in: + Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset + Joao Carreira, Andrew Zisserman + https://arxiv.org/pdf/1705.07750v1.pdf. + See also the Inception architecture, introduced in: + Going deeper with convolutions + Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, + Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. + http://arxiv.org/pdf/1409.4842v1.pdf. + """ + + # Endpoints of the model in order. During construction, all the endpoints up + # to a designated `final_endpoint` are returned in a dictionary as the + # second return value. + VALID_ENDPOINTS = ( + "Conv3d_1a_7x7", + "MaxPool3d_2a_3x3", + "Conv3d_2b_1x1", + "Conv3d_2c_3x3", + "MaxPool3d_3a_3x3", + "Mixed_3b", + "Mixed_3c", + "MaxPool3d_4a_3x3", + "Mixed_4b", + "Mixed_4c", + "Mixed_4d", + "Mixed_4e", + "Mixed_4f", + "MaxPool3d_5a_2x2", + "Mixed_5b", + "Mixed_5c", + "Logits", + "Predictions", + ) + + def __init__( + self, + num_classes=400, + spatial_squeeze=True, + final_endpoint="Logits", + name="inception_i3d", + in_channels=3, + dropout_keep_prob=0.5, + ): + """Initializes I3D model instance. + Args: + num_classes: The number of outputs in the logit layer (default 400, which + matches the Kinetics dataset). + spatial_squeeze: Whether to squeeze the spatial dimensions for the logits + before returning (default True). + final_endpoint: The model contains many possible endpoints. + `final_endpoint` specifies the last endpoint for the model to be built + up to. In addition to the output at `final_endpoint`, all the outputs + at endpoints up to `final_endpoint` will also be returned, in a + dictionary. `final_endpoint` must be one of + InceptionI3d.VALID_ENDPOINTS (default 'Logits'). + name: A string (optional). The name of this module. + Raises: + ValueError: if `final_endpoint` is not recognized. + """ + + if final_endpoint not in self.VALID_ENDPOINTS: + raise ValueError("Unknown final endpoint %s" % final_endpoint) + + super(InceptionI3d, self).__init__() + self._num_classes = num_classes + self._spatial_squeeze = spatial_squeeze + self._final_endpoint = final_endpoint + self.logits = None + + if self._final_endpoint not in self.VALID_ENDPOINTS: + raise ValueError("Unknown final endpoint %s" % self._final_endpoint) + + self.end_points = {} + end_point = "Conv3d_1a_7x7" + self.end_points[end_point] = Unit3D( + in_channels=in_channels, + output_channels=64, + kernel_shape=[7, 7, 7], + stride=(2, 2, 2), + padding=(3, 3, 3), + name=name + end_point, + ) + if self._final_endpoint == end_point: + return + + end_point = "MaxPool3d_2a_3x3" + self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0) + if self._final_endpoint == end_point: + return + + end_point = "Conv3d_2b_1x1" + self.end_points[end_point] = Unit3D( + in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0, name=name + end_point + ) + if self._final_endpoint == end_point: + return + + end_point = "Conv3d_2c_3x3" + self.end_points[end_point] = Unit3D( + in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1, name=name + end_point + ) + if self._final_endpoint == end_point: + return + + end_point = "MaxPool3d_3a_3x3" + self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0) + if self._final_endpoint == end_point: + return + + end_point = "Mixed_3b" + self.end_points[end_point] = InceptionModule(192, [64, 96, 128, 16, 32, 32], name + end_point) + if self._final_endpoint == end_point: + return + + end_point = "Mixed_3c" + self.end_points[end_point] = InceptionModule(256, [128, 128, 192, 32, 96, 64], name + end_point) + if self._final_endpoint == end_point: + return + + end_point = "MaxPool3d_4a_3x3" + self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2), padding=0) + if self._final_endpoint == end_point: + return + + end_point = "Mixed_4b" + self.end_points[end_point] = InceptionModule(128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point) + if self._final_endpoint == end_point: + return + + end_point = "Mixed_4c" + self.end_points[end_point] = InceptionModule(192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point) + if self._final_endpoint == end_point: + return + + end_point = "Mixed_4d" + self.end_points[end_point] = InceptionModule(160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point) + if self._final_endpoint == end_point: + return + + end_point = "Mixed_4e" + self.end_points[end_point] = InceptionModule(128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point) + if self._final_endpoint == end_point: + return + + end_point = "Mixed_4f" + self.end_points[end_point] = InceptionModule( + 112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128], name + end_point + ) + if self._final_endpoint == end_point: + return + + end_point = "MaxPool3d_5a_2x2" + self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2), padding=0) + if self._final_endpoint == end_point: + return + + end_point = "Mixed_5b" + self.end_points[end_point] = InceptionModule( + 256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128], name + end_point + ) + if self._final_endpoint == end_point: + return + + end_point = "Mixed_5c" + self.end_points[end_point] = InceptionModule( + 256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128], name + end_point + ) + if self._final_endpoint == end_point: + return + + end_point = "Logits" + self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], stride=(1, 1, 1)) + self.dropout = nn.Dropout(p=dropout_keep_prob) + self.logits = Unit3D( + in_channels=384 + 384 + 128 + 128, + output_channels=self._num_classes, + kernel_shape=[1, 1, 1], + padding=0, + activation_fn=None, + use_batch_norm=False, + use_bias=True, + name="logits", + ) + + self.build() + + def replace_logits(self, num_classes): + self._num_classes = num_classes + self.logits = Unit3D( + in_channels=384 + 384 + 128 + 128, + output_channels=self._num_classes, + kernel_shape=[1, 1, 1], + padding=0, + activation_fn=None, + use_batch_norm=False, + use_bias=True, + name="logits", + ) + + def build(self): + for k in self.end_points.keys(): + self.add_module(k, self.end_points[k]) + + def construct(self, x): + for end_point in self.VALID_ENDPOINTS: + if end_point in self.end_points: + x = self._modules[end_point](x) # use _modules to work with dataparallel + + x = self.logits(self.dropout(self.avg_pool(x))) + if self._spatial_squeeze: + logits = x.squeeze(3).squeeze(3) + logits = logits.mean(axis=2) + # logits is batch X time X classes, which is what we want to work with + return logits + + def extract_features(self, x): + for end_point in self.VALID_ENDPOINTS: + if end_point in self.end_points: + x = self._modules[end_point](x) + return self.avg_pool(x) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/eval/script/cal_fvd.sh b/examples/opensora_pku/opensora/models/causalvideovae/eval/script/cal_fvd.sh new file mode 100644 index 0000000000..d4d4c730b0 --- /dev/null +++ b/examples/opensora_pku/opensora/models/causalvideovae/eval/script/cal_fvd.sh @@ -0,0 +1,9 @@ +python opensora/eval/eval.py \ + --real_video_dir /data/xiaogeng_liu/data/video1 \ + --generated_video_dir /data/xiaogeng_liu/data/video2 \ + --batch_size 10 \ + --crop_size 64 \ + --num_frames 20 \ + --device 'Ascend' \ + --metric 'fvd' \ + --fvd_method 'styleganv' diff --git a/examples/opensora_pku/opensora/models/causalvideovae/eval/script/cal_lpips.sh b/examples/opensora_pku/opensora/models/causalvideovae/eval/script/cal_lpips.sh index df4d800bfd..c6c1c3abb9 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/eval/script/cal_lpips.sh +++ b/examples/opensora_pku/opensora/models/causalvideovae/eval/script/cal_lpips.sh @@ -1,4 +1,4 @@ -python opensora/eval/eval_common_metrics.py \ +python opensora/eval/eval.py \ --real_video_dir /data/xiaogeng_liu/data/video1 \ --generated_video_dir /data/xiaogeng_liu/data/video2 \ --batch_size 10 \ diff --git a/examples/opensora_pku/opensora/models/causalvideovae/eval/script/cal_psnr.sh b/examples/opensora_pku/opensora/models/causalvideovae/eval/script/cal_psnr.sh index 9e43f9b3e6..0a301b82f4 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/eval/script/cal_psnr.sh +++ b/examples/opensora_pku/opensora/models/causalvideovae/eval/script/cal_psnr.sh @@ -1,4 +1,4 @@ -python opensora/eval/eval_common_metrics.py \ +python opensora/eval/eval.py \ --real_video_dir /data/xiaogeng_liu/data/video1 \ --generated_video_dir /data/xiaogeng_liu/data/video2 \ --batch_size 10 \ diff --git a/examples/opensora_pku/opensora/models/causalvideovae/eval/script/cal_ssim.sh b/examples/opensora_pku/opensora/models/causalvideovae/eval/script/cal_ssim.sh index 2cde07b428..bad324c9a7 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/eval/script/cal_ssim.sh +++ b/examples/opensora_pku/opensora/models/causalvideovae/eval/script/cal_ssim.sh @@ -1,4 +1,4 @@ -python opensora/eval/eval_common_metrics.py \ +python opensora/eval/eval.py \ --real_video_dir /data/xiaogeng_liu/data/video1 \ --generated_video_dir /data/xiaogeng_liu/data/video2 \ --batch_size 10 \ diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/__init__.py b/examples/opensora_pku/opensora/models/causalvideovae/model/__init__.py index 5d3351be6e..9713ad5977 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/__init__.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/__init__.py @@ -9,44 +9,5 @@ # CausalVQVAETrainer, # CausalVQVAEModel, CausalVQVAEModelWrapper # ) -from .causal_vae import CausalVAEModelWrapper -from .causal_vae.modeling_causalvae import CausalVAEModel -from .ema_model import EMA - -videobase_ae_stride = { - "CausalVAEModel_4x8x8": [4, 8, 8], - # 'CausalVQVAEModel_4x4x4': [4, 4, 4], - # 'CausalVQVAEModel_4x8x8': [4, 8, 8], - # 'VQVAEModel_4x4x4': [4, 4, 4], - # 'OpenVQVAEModel_4x4x4': [4, 4, 4], - # 'VQVAEModel_4x8x8': [4, 8, 8], - # 'bair_stride4x2x2': [4, 2, 2], - # 'ucf101_stride4x4x4': [4, 4, 4], - # 'kinetics_stride4x4x4': [4, 4, 4], - # 'kinetics_stride2x4x4': [2, 4, 4], -} - -videobase_ae_channel = { - "CausalVAEModel_4x8x8": 4, - # 'CausalVQVAEModel_4x4x4': 4, - # 'CausalVQVAEModel_4x8x8': 4, - # 'VQVAEModel_4x4x4': 4, - # 'OpenVQVAEModel_4x4x4': 4, - # 'VQVAEModel_4x8x8': 4, - # 'bair_stride4x2x2': 256, - # 'ucf101_stride4x4x4': 256, - # 'kinetics_stride4x4x4': 256, - # 'kinetics_stride2x4x4': 256, -} - -videobase_ae = { - "CausalVAEModel_4x8x8": CausalVAEModelWrapper, - # 'CausalVQVAEModel_4x4x4': CausalVQVAEModelWrapper, - # 'CausalVQVAEModel_4x8x8': CausalVQVAEModelWrapper, - # 'VQVAEModel_4x4x4': VQVAEModelWrapper, - # 'VQVAEModel_4x8x8': VQVAEModelWrapper, - # "bair_stride4x2x2": VQVAEModelWrapper, - # "ucf101_stride4x4x4": VQVAEModelWrapper, - # "kinetics_stride4x4x4": VQVAEModelWrapper, - # "kinetics_stride2x4x4": VQVAEModelWrapper, -} +from .registry import ModelRegistry +from .vae import CausalVAEModel, WFVAEModel diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/dataset_videobase.py b/examples/opensora_pku/opensora/models/causalvideovae/model/dataset_videobase.py index 80863fb0c2..67b7b0100b 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/dataset_videobase.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/dataset_videobase.py @@ -63,6 +63,7 @@ def __init__( self.read_from_data_file = False self.length = len(self.dataset) + assert self.length > 0, "The input dataset must not be empty!" logger.info(f"Num data samples: {self.length}") logger.info(f"sample_n_frames: {sample_n_frames}") diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/__init__.py b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/__init__.py index e69de29bb2..3a7b6e21c3 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/__init__.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/__init__.py @@ -0,0 +1 @@ +from .perceptual_loss import LPIPSWithDiscriminator3D diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py index 255b396f03..03c28c8967 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py @@ -1,8 +1,10 @@ import functools from typing import Tuple, Union +from opensora.npu_config import npu_config + import mindspore as ms -from mindspore import nn, ops +from mindspore import nn def weights_init(m): @@ -14,124 +16,6 @@ def weights_init(m): nn.init.constant_(m.bias.data, 0) -class Conv3d(nn.Cell): - def __init__( - self, - input_nc, - ndf, - kernel_size: Union[int, Tuple[int, int, int]], - stride: Union[int, Tuple[int, int, int]] = 1, - pad_mode: str = "pad", - padding: int = 0, - has_bias: bool = True, - dtype=ms.bfloat16, - **kwargs, - ): - super().__init__() - self.conv = nn.Conv3d( - input_nc, - ndf, - kernel_size=kernel_size, - stride=stride, - pad_mode=pad_mode, - padding=padding, - has_bias=has_bias, - **kwargs, - ).to_float(dtype) - self.dtype = dtype - - def construct(self, x): - if x.dtype == ms.float32: - return self.conv(x).to(ms.float32) - else: - return self.conv(x) - - -class NLayerDiscriminator(nn.Cell): - """Defines a PatchGAN discriminator as in Pix2Pix - --> refer to: https://github.com/junyanz/pyms-CycleGAN-and-pix2pix/blob/master/models/networks.py - """ - - def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.float32): - """Construct a PatchGAN discriminator - Parameters: - input_nc (int) -- the number of channels in input images - ndf (int) -- the number of filters in the last conv layer - n_layers (int) -- the number of conv layers in the discriminator - norm_layer -- normalization layer - """ - - # TODO: check forward consistency!!! - super().__init__() - self.dtype = dtype - if not use_actnorm: - norm_layer = nn.BatchNorm2d - else: - # norm_layer = ActNorm - raise NotImplementedError - if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters - use_bias = norm_layer.func != nn.BatchNorm2d - else: - use_bias = norm_layer != nn.BatchNorm2d - - kw = 4 - padw = 1 - # Fixed - sequence = [ - nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, pad_mode="pad", padding=padw, has_bias=True).to_float( - self.dtype - ), - nn.LeakyReLU(0.2), - ] - nf_mult = 1 - nf_mult_prev = 1 - for n in range(1, n_layers): # gradually increase the number of filters - nf_mult_prev = nf_mult - nf_mult = min(2**n, 8) - sequence += [ - nn.Conv2d( - ndf * nf_mult_prev, - ndf * nf_mult, - kernel_size=kw, - stride=2, - pad_mode="pad", - padding=padw, - has_bias=use_bias, - ).to_float(self.dtype), - norm_layer(ndf * nf_mult, momentum=0.1), - nn.LeakyReLU(0.2), - ] - - nf_mult_prev = nf_mult - nf_mult = min(2**n_layers, 8) - sequence += [ - nn.Conv2d( - ndf * nf_mult_prev, - ndf * nf_mult, - kernel_size=kw, - stride=1, - pad_mode="pad", - padding=padw, - has_bias=use_bias, - ).to_float(self.dtype), - norm_layer(ndf * nf_mult, momentum=0.1), - nn.LeakyReLU(0.2), - ] - - # output 1 channel prediction map - sequence += [ - nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, pad_mode="pad", padding=padw, has_bias=True).to_float( - self.dtype - ) - ] - self.main = nn.SequentialCell(sequence) - self.cast = ops.Cast() - - def construct(self, x): - y = self.main(x) - return y - - class NLayerDiscriminator3D(nn.Cell): """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs.""" @@ -159,8 +43,8 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.f kw = 3 padw = 1 sequence = [ - Conv3d(input_nc, ndf, kernel_size=kw, stride=2, pad_mode="pad", padding=padw, has_bias=True), - nn.LeakyReLU(0.2), + nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, pad_mode="pad", padding=padw, has_bias=True), + nn.LeakyReLU(0.2).to_float(self.dtype), ] nf_mult = 1 nf_mult_prev = 1 @@ -178,7 +62,7 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.f has_bias=use_bias, ), norm_layer(ndf * nf_mult), - nn.LeakyReLU(0.2), + nn.LeakyReLU(0.2).to_float(self.dtype), ] nf_mult_prev = nf_mult @@ -194,14 +78,22 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.f has_bias=use_bias, ), norm_layer(ndf * nf_mult), - nn.LeakyReLU(0.2), + nn.LeakyReLU(0.2).to_float(self.dtype), ] sequence += [ Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw, pad_mode="pad", has_bias=True) ] # output 1 channel prediction map - self.main = nn.SequentialCell(*sequence) + self.main = nn.CellList(sequence) def construct(self, input): """Standard forward.""" - return self.main(input) + x = input + for layer in self.main: + if isinstance(layer, nn.Conv3d): + x = npu_config.run_conv3d(layer, x, x.dtype) + elif isinstance(layer, nn.BatchNorm3d): + x = npu_config.run_batch_norm(layer, x) + else: + x = layer(x) + return x diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/lpips.py b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/lpips.py index 560a75146c..4c31a52a2e 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/lpips.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/lpips.py @@ -5,7 +5,7 @@ import mindspore as ms import mindspore.nn as nn -import mindspore.ops as ops +from mindspore import mint _logger = logging.getLogger(__name__) @@ -57,7 +57,7 @@ def construct(self, input, target): diff = (normalize_tensor(outs0[kk]) - normalize_tensor(outs1[kk])) ** 2 # res += spatial_average(lins[kk](diff), keepdim=True) # lin_layer = lins[kk] - val += ops.mean(lins[kk](diff), axis=[2, 3], keep_dims=True) + val += mint.mean(lins[kk](diff), dim=[2, 3], keepdim=True) return val @@ -138,7 +138,7 @@ def construct(self, X): def normalize_tensor(x, eps=1e-10): - norm_factor = ops.sqrt((x**2).sum(1, keepdims=True)) + norm_factor = mint.sqrt((x**2).sum(1, keepdims=True)) return x / (norm_factor + eps) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py index 82b267022a..3c0d67d400 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py @@ -1,5 +1,5 @@ import mindspore as ms -from mindspore import nn, ops +from mindspore import mint, nn, ops from .lpips import LPIPS @@ -20,19 +20,30 @@ def _rearrange_out(x, t): return x +def l1(x, y): + return mint.abs(x - y) + + +def l2(x, y): + return mint.pow((x - y), 2) + + class GeneratorWithLoss(nn.Cell): def __init__( self, autoencoder, disc_start=50001, - kl_weight=1.0e-06, + kl_weight=1.0, + perceptual_weight=1.0, + pixelloss_weight=1.0, disc_weight=0.5, disc_factor=1.0, - perceptual_weight=1.0, logvar_init=0.0, discriminator=None, dtype=ms.float32, lpips_ckpt_path=None, + learn_logvar: bool = False, + wavelet_weight=0.01, loss_type: str = "l1", ): super().__init__() @@ -50,10 +61,12 @@ def __init__( l2 = ops.L2Loss() self.loss_func = l1 if loss_type == "l1" else l2 # TODO: is self.logvar trainable? - self.logvar = ms.Parameter(ms.Tensor([logvar_init], dtype=dtype)) + self.logvar = ms.Parameter(ms.Tensor([logvar_init], dtype=dtype), requires_grad=learn_logvar) self.disc_start = disc_start + self.wavelet_weight = wavelet_weight self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight self.disc_weight = disc_weight self.disc_factor = disc_factor self.perceptual_weight = perceptual_weight @@ -66,14 +79,26 @@ def kl(self, mean, logvar): mean = mean.astype(ms.float32) logvar = logvar.astype(ms.float32) - var = ops.exp(logvar) - kl_loss = 0.5 * ops.sum( - ops.pow(mean, 2) + var - 1.0 - logvar, + var = mint.exp(logvar) + kl_loss = 0.5 * mint.sum( + mint.pow(mean, 2) + var - 1.0 - logvar, dim=[1, 2, 3], ) return kl_loss - def loss_function(self, x, recons, mean, logvar, global_step: ms.Tensor = -1, weights: ms.Tensor = None, cond=None): + def loss_function( + self, + x, + recons, + mean, + logvar, + global_step: ms.Tensor = -1, + weights: ms.Tensor = None, + cond=None, + wavelet_coeffs=None, + ): + bs = x.shape[0] + # For videos, treat them as independent frame images # TODO: regularize on temporal consistency t = x.shape[2] @@ -89,7 +114,7 @@ def loss_function(self, x, recons, mean, logvar, global_step: ms.Tensor = -1, we p_loss = self.perceptual_loss(x, recons) rec_loss = rec_loss + self.perceptual_weight * p_loss - nll_loss = rec_loss / ops.exp(self.logvar) + self.logvar + nll_loss = rec_loss / mint.exp(self.logvar) + self.logvar if weights is not None: weighted_nll_loss = weights * nll_loss mean_weighted_nll_loss = weighted_nll_loss.sum() / weighted_nll_loss.shape[0] @@ -100,9 +125,14 @@ def loss_function(self, x, recons, mean, logvar, global_step: ms.Tensor = -1, we # 2.3 kl loss kl_loss = self.kl(mean, logvar) - kl_loss = kl_loss.sum() / kl_loss.shape[0] - - loss = mean_weighted_nll_loss + self.kl_weight * kl_loss + kl_loss = kl_loss.sum() / bs + if wavelet_coeffs: + wl_loss_l2 = mint.sum(l1(wavelet_coeffs[0], wavelet_coeffs[1])) / bs + wl_loss_l3 = mint.sum(l1(wavelet_coeffs[2], wavelet_coeffs[3])) / bs + wl_loss = wl_loss_l2 + wl_loss_l3 + else: + wl_loss = 0 + loss = mean_weighted_nll_loss + self.kl_weight * kl_loss + self.wavelet_weight * wl_loss # 2.4 discriminator loss if enabled # g_loss = ms.Tensor(0., dtype=ms.float32) @@ -114,7 +144,7 @@ def loss_function(self, x, recons, mean, logvar, global_step: ms.Tensor = -1, we if cond is None: logits_fake = self.discriminator(recons) else: - logits_fake = self.discriminator(ops.concat((recons, cond), dim=1)) + logits_fake = self.discriminator(mint.cat((recons, cond), dim=1)) g_loss = -ops.reduce_mean(logits_fake) # TODO: do adaptive weighting based on grad # d_weight = self.calculate_adaptive_weight(mean_nll_loss, g_loss, last_layer=last_layer) @@ -129,9 +159,10 @@ def loss_function(self, x, recons, mean, logvar, global_step: ms.Tensor = -1, we "{}/kl_loss".format(split): kl_loss.asnumpy().mean(), "{}/nll_loss".format(split): nll_loss.asnumpy().mean(), "{}/rec_loss".format(split): rec_loss.asnumpy().mean(), + "{}/wl_loss".format(split): wl_loss.asnumpy().mean(), # "{}/d_weight".format(split): d_weight.detach(), # "{}/disc_factor".format(split): torch.tensor(disc_factor), - # "{}/g_loss".format(split): g_loss.detach().mean(), + # "{}/g_loss".format(split): g_loss.asnumpy().mean(), } for k in log: print(k.split("/")[1], log[k]) @@ -149,10 +180,16 @@ def construct(self, x: ms.Tensor, global_step: ms.Tensor = -1, weights: ms.Tenso """ # 1. AE forward, get posterior (mean, logvar) and recons - recons, mean, logvar = self.autoencoder(x) + outputs = self.autoencoder(x) + if len(outputs) == 3: + recons, mean, logvar = outputs + wavelet_coeffs = None + elif len(outputs) == 4: + # which means there is wavelet output + recons, mean, logvar, wavelet_coeffs = outputs # 2. compuate loss - loss = self.loss_function(x, recons, mean, logvar, global_step, weights, cond) + loss = self.loss_function(x, recons, mean, logvar, global_step, weights, cond, wavelet_coeffs) return loss @@ -190,17 +227,17 @@ def __init__( if disc_loss == "hinge": self.disc_loss = self.hinge_loss else: - self.softplus = ops.Softplus() + self.softplus = mint.nn.functional.softplus self.disc_loss = self.vanilla_d_loss def hinge_loss(self, logits_real, logits_fake): - loss_real = ops.mean(ops.relu(1.0 - logits_real)) - loss_fake = ops.mean(ops.relu(1.0 + logits_fake)) + loss_real = mint.mean(mint.nn.functional.relu(1.0 - logits_real)) + loss_fake = mint.mean(mint.nn.functional.relu(1.0 + logits_fake)) d_loss = 0.5 * (loss_real + loss_fake) return d_loss def vanilla_d_loss(self, logits_real, logits_fake): - d_loss = 0.5 * (ops.mean(self.softplus(-logits_real)) + ops.mean(self.softplus(logits_fake))) + d_loss = 0.5 * (mint.mean(self.softplus(-logits_real)) + mint.mean(self.softplus(logits_fake))) return d_loss def construct(self, x: ms.Tensor, global_step=-1, cond=None): @@ -225,8 +262,8 @@ def construct(self, x: ms.Tensor, global_step=-1, cond=None): logits_real = self.discriminator(x) logits_fake = self.discriminator(recons) else: - logits_real = self.discriminator(ops.concat((x, cond), dim=1)) - logits_fake = self.discriminator(ops.concat((recons, cond), dim=1)) + logits_real = self.discriminator(mint.cat((x, cond), dim=1)) + logits_fake = self.discriminator(mint.cat((recons, cond), dim=1)) if global_step >= self.disc_start: disc_factor = self.disc_factor diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/perceptual_loss.py b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/perceptual_loss.py index 770555aa97..303064860c 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/perceptual_loss.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/perceptual_loss.py @@ -1,26 +1,28 @@ import mindspore as ms -from mindspore import nn, ops +from mindspore import mint, nn, ops -from .discriminator import NLayerDiscriminator, NLayerDiscriminator3D, weights_init +from .discriminator import NLayerDiscriminator3D, weights_init from .lpips import LPIPS def hinge_d_loss(logits_real, logits_fake): - loss_real = ops.mean(ops.relu(1.0 - logits_real)) - loss_fake = ops.mean(ops.relu(1.0 + logits_fake)) + loss_real = mint.mean(mint.nn.functional.relu(1.0 - logits_real)) + loss_fake = mint.mean(mint.nn.functional.relu(1.0 + logits_fake)) d_loss = 0.5 * (loss_real + loss_fake) return d_loss def vanilla_d_loss(logits_real, logits_fake): - d_loss = 0.5 * (ops.mean(ops.softplus(-logits_real)) + ops.mean(ops.softplus(logits_fake))) + d_loss = 0.5 * ( + mint.mean(mint.nn.functional.softplus(-logits_real)) + mint.mean(mint.nn.functional.softplus(logits_fake)) + ) return d_loss def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] - loss_real = ops.mean(ops.relu(1.0 - logits_real), dim=[1, 2, 3]) - loss_fake = ops.mean(ops.relu(1.0 + logits_fake), dim=[1, 2, 3]) + loss_real = mint.mean(mint.nn.functional.relu(1.0 - logits_real), dim=[1, 2, 3]) + loss_fake = mint.mean(mint.nn.functional.relu(1.0 + logits_fake), dim=[1, 2, 3]) loss_real = (weights * loss_real).sum() / weights.sum() loss_fake = (weights * loss_fake).sum() / weights.sum() d_loss = 0.5 * (loss_real + loss_fake) @@ -36,150 +38,19 @@ def adopt_weight(weight, global_step, threshold=0, value=0.0): def measure_perplexity(predicted_indices, n_embed): # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally - encodings = ops.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) + encodings = mint.nn.functional.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) avg_probs = encodings.mean(0) - perplexity = (-(avg_probs * ops.log(avg_probs + 1e-10)).sum()).exp() - cluster_use = ops.sum(avg_probs > 0) + perplexity = (-(avg_probs * mint.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = mint.sum(avg_probs > 0) return perplexity, cluster_use def l1(x, y): - return ops.abs(x - y) + return mint.abs(x - y) def l2(x, y): - return ops.pow((x - y), 2) - - -class LPIPSWithDiscriminator(nn.Cell): - def __init__( - self, - disc_start, - logvar_init=0.0, - kl_weight=1.0, - pixelloss_weight=1.0, - perceptual_weight=1.0, - # --- Discriminator Loss --- - disc_num_layers=3, - disc_in_channels=3, - disc_factor=1.0, - disc_weight=1.0, - use_actnorm=False, - disc_conditional=False, - disc_loss="hinge", - ): - super().__init__() - assert disc_loss in ["hinge", "vanilla"] - self.kl_weight = kl_weight - self.pixel_weight = pixelloss_weight - self.perceptual_loss = LPIPS().eval() - self.perceptual_weight = perceptual_weight - self.logvar = ms.Parameter(ops.ones(shape=[]) * logvar_init) - - self.discriminator = NLayerDiscriminator( - input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm - ) # .apply(weights_init) - self.discriminator_iter_start = disc_start - self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss - self.disc_factor = disc_factor - self.discriminator_weight = disc_weight - self.disc_conditional = disc_conditional - - def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): - if last_layer is not None: - nll_grads = ms.grad(nll_loss, last_layer)[0] - g_grads = ms.grad(g_loss, last_layer)[0] - else: - nll_grads = ms.grad(nll_loss, self.last_layer[0])[0] - g_grads = ms.grad(g_loss, self.last_layer[0])[0] - - d_weight = ops.norm(nll_grads) / (ops.norm(g_grads) + 1e-4) - d_weight = ops.clamp(d_weight, 0.0, 1e4).detach() - d_weight = d_weight * self.discriminator_weight - return d_weight - - def construct( - self, - inputs, - reconstructions, - posteriors, - optimizer_idx, - global_step, - split="train", - weights=None, - last_layer=None, - cond=None, - ): - # b c t h w -> (b t) c h w - b, c, t, h, w = inputs.shape - inputs = inputs.permute(0, 2, 1, 3, 4).reshape(-1, c, h, w) - # b c t h w -> (b t) c h w - b, c, t, h, w = reconstructions.shape - reconstructions = reconstructions.permute(0, 2, 1, 3, 4).reshape(-1, c, h, w) - rec_loss = ops.abs(inputs - reconstructions) - if self.perceptual_weight > 0: - p_loss = self.perceptual_loss(inputs, reconstructions) - rec_loss = rec_loss + self.perceptual_weight * p_loss - nll_loss = rec_loss / ops.exp(self.logvar) + self.logvar - weighted_nll_loss = nll_loss - if weights is not None: - weighted_nll_loss = weights * nll_loss - weighted_nll_loss = ops.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] - nll_loss = ops.sum(nll_loss) / nll_loss.shape[0] - kl_loss = posteriors.kl() - kl_loss = ops.sum(kl_loss) / kl_loss.shape[0] - - # GAN Part - if optimizer_idx == 0: - # generator update - if cond is None: - assert not self.disc_conditional - logits_fake = self.discriminator(reconstructions) - else: - assert self.disc_conditional - logits_fake = self.discriminator(ops.cat((reconstructions, cond), axis=1)) - g_loss = -ops.mean(logits_fake) - - if self.disc_factor > 0.0: - try: - d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) - except RuntimeError: - assert not self.training - d_weight = ms.Tensor(0.0) - else: - d_weight = ms.Tensor(0.0) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss - log = { - "{}/total_loss".format(split): loss.clone().detach().mean(), - "{}/logvar".format(split): self.logvar.detach(), - "{}/kl_loss".format(split): kl_loss.detach().mean(), - "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): ms.Tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } - return loss, log - - if optimizer_idx == 1: - if cond is None: - logits_real = self.discriminator(inputs.detach()) - logits_fake = self.discriminator(reconstructions.detach()) - else: - logits_real = self.discriminator(ops.cat((inputs.detach(), cond), axis=1)) - logits_fake = self.discriminator(ops.cat((reconstructions.detach(), cond), axis=1)) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - - log = { - "{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean(), - } - return d_loss, log + return mint.pow((x - y), 2) class LPIPSWithDiscriminator3D(nn.Cell): @@ -191,7 +62,7 @@ def __init__( pixelloss_weight=1.0, perceptual_weight=1.0, # --- Discriminator Loss --- - disc_num_layers=3, + disc_num_layers=4, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, @@ -205,7 +76,7 @@ def __init__( self.pixel_weight = pixelloss_weight self.perceptual_loss = LPIPS().eval() self.perceptual_weight = perceptual_weight - self.logvar = ms.Parameter(ops.ones(shape=[]) * logvar_init) + self.logvar = ms.Parameter(mint.ones(shape=[]) * logvar_init) self.discriminator = NLayerDiscriminator3D( input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm @@ -225,7 +96,7 @@ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): g_grads = ms.grad(g_loss, self.last_layer[0])[0] d_weight = ops.norm(nll_grads) / (ops.norm(g_grads) + 1e-4) - d_weight = ops.clamp(d_weight, 0.0, 1e4).detach() + d_weight = mint.clamp(d_weight, 0.0, 1e4) d_weight = d_weight * self.discriminator_weight return d_weight @@ -247,18 +118,18 @@ def construct( # b c t h w -> (b t) c h w b, c, t, h, w = reconstructions.shape reconstructions = reconstructions.permute(0, 2, 1, 3, 4).reshape(-1, c, h, w) - rec_loss = ops.abs(inputs - reconstructions) + rec_loss = mint.abs(inputs - reconstructions) if self.perceptual_weight > 0: p_loss = self.perceptual_loss(inputs, reconstructions) rec_loss = rec_loss + self.perceptual_weight * p_loss - nll_loss = rec_loss / ops.exp(self.logvar) + self.logvar + nll_loss = rec_loss / mint.exp(self.logvar) + self.logvar weighted_nll_loss = nll_loss if weights is not None: weighted_nll_loss = weights * nll_loss - weighted_nll_loss = ops.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] - nll_loss = ops.sum(nll_loss) / nll_loss.shape[0] + weighted_nll_loss = mint.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = mint.sum(nll_loss) / nll_loss.shape[0] kl_loss = posteriors.kl() - kl_loss = ops.sum(kl_loss) / kl_loss.shape[0] + kl_loss = mint.sum(kl_loss) / kl_loss.shape[0] # (b t) c h w -> b c t h w _, c, h, w = inputs.shape inputs = inputs.reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4) @@ -272,8 +143,8 @@ def construct( logits_fake = self.discriminator(reconstructions) else: assert self.disc_conditional - logits_fake = self.discriminator(ops.cat((reconstructions, cond), axis=1)) - g_loss = -ops.mean(logits_fake) + logits_fake = self.discriminator(mint.cat((reconstructions, cond), dim=1)) + g_loss = -mint.mean(logits_fake) if self.disc_factor > 0.0: try: @@ -287,87 +158,31 @@ def construct( disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss # log = { - # "{}/total_loss".format(split): loss.clone().detach().mean(), - # "{}/logvar".format(split): self.logvar.detach(), - # "{}/kl_loss".format(split): kl_loss.detach().mean(), - # "{}/nll_loss".format(split): nll_loss.detach().mean(), - # "{}/rec_loss".format(split): rec_loss.detach().mean(), - # "{}/d_weight".format(split): d_weight.detach(), + # "{}/total_loss".format(split): loss.clone().mean(), + # "{}/logvar".format(split): self.logvar, + # "{}/kl_loss".format(split): kl_loss.mean(), + # "{}/nll_loss".format(split): nll_loss.mean(), + # "{}/rec_loss".format(split): rec_loss.mean(), + # "{}/d_weight".format(split): d_weight, # "{}/disc_factor".format(split): ms.Tensor(disc_factor), - # "{}/g_loss".format(split): g_loss.detach().mean(), + # "{}/g_loss".format(split): g_loss.mean(), # } return loss if optimizer_idx == 1: if cond is None: - logits_real = self.discriminator(inputs.detach()) - logits_fake = self.discriminator(reconstructions.detach()) + logits_real = self.discriminator(inputs) + logits_fake = self.discriminator(reconstructions) else: - logits_real = self.discriminator(ops.cat((inputs.detach(), cond), axis=1)) - logits_fake = self.discriminator(ops.cat((reconstructions.detach(), cond), axis=1)) + logits_real = self.discriminator(mint.cat((inputs, cond), dim=1)) + logits_fake = self.discriminator(mint.cat((reconstructions, cond), dim=1)) disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) # log = { - # "{}/disc_loss".format(split): d_loss.clone().detach().mean(), - # "{}/logits_real".format(split): logits_real.detach().mean(), - # "{}/logits_fake".format(split): logits_fake.detach().mean(), + # "{}/disc_loss".format(split): d_loss.clone().mean(), + # "{}/logits_real".format(split): logits_real.mean(), + # "{}/logits_fake".format(split): logits_fake.mean(), # } return d_loss # , log - - -class SimpleLPIPS(nn.Cell): - def __init__( - self, - logvar_init=0.0, - kl_weight=1.0, - pixelloss_weight=1.0, - perceptual_weight=1.0, - disc_loss="hinge", - ): - super().__init__() - assert disc_loss in ["hinge", "vanilla"] - self.kl_weight = kl_weight - self.pixel_weight = pixelloss_weight - self.perceptual_loss = LPIPS().eval() - self.perceptual_weight = perceptual_weight - self.logvar = ms.Parameter(ops.ones(shape=()) * logvar_init) - - def construct( - self, - inputs, - reconstructions, - posteriors, - split="train", - weights=None, - ): - # b c t h w -> (b t) c h w - b, c, t, h, w = inputs.shape - inputs = inputs.permute(0, 2, 1, 3, 4).reshape(-1, c, h, w) - # b c t h w -> (b t) c h w - b, c, t, h, w = reconstructions.shape - reconstructions = reconstructions.permute(0, 2, 1, 3, 4).reshape(-1, c, h, w) - rec_loss = ops.abs(inputs - reconstructions) - if self.perceptual_weight > 0: - p_loss = self.perceptual_loss(inputs, reconstructions) - rec_loss = rec_loss + self.perceptual_weight * p_loss - nll_loss = rec_loss / ops.exp(self.logvar) + self.logvar - weighted_nll_loss = nll_loss - if weights is not None: - weighted_nll_loss = weights * nll_loss - weighted_nll_loss = ops.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] - nll_loss = ops.sum(nll_loss) / nll_loss.shape[0] - kl_loss = posteriors.kl() - kl_loss = ops.sum(kl_loss) / kl_loss.shape[0] - loss = weighted_nll_loss + self.kl_weight * kl_loss - log = { - "{}/total_loss".format(split): loss.clone().detach().mean(), - "{}/logvar".format(split): self.logvar.detach(), - "{}/kl_loss".format(split): kl_loss.detach().mean(), - "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - } - if self.perceptual_weight > 0: - log.update({"{}/p_loss".format(split): p_loss.detach().mean()}) - return loss, log diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/__init__.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/__init__.py index 7fa25ab4d0..1282f19639 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/__init__.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/__init__.py @@ -1,17 +1,7 @@ -from .attention import AttnBlock, AttnBlock3D, AttnBlock3DFix # LinAttnBlock,; LinearAttention,; TemporalAttnBlock +from .attention import * from .block import Block -from .conv import CausalConv3d, Conv2d -from .normalize import GroupNormExtend, Normalize -from .resnet_block import ResnetBlock2D, ResnetBlock3D -from .updownsample import ( # TimeDownsampleResAdv2x,; TimeUpsampleResAdv2x - Downsample, - Spatial2xTime2x3DDownsample, - Spatial2xTime2x3DUpsample, - SpatialDownsample2x, - SpatialUpsample2x, - TimeDownsample2x, - TimeDownsampleRes2x, - TimeUpsample2x, - TimeUpsampleRes2x, - Upsample, -) +from .conv import * +from .normalize import * +from .resnet_block import * +from .updownsample import * +from .wavelet import * diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/attention.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/attention.py index 8f6ea72056..3803ba9d85 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/attention.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/attention.py @@ -1,82 +1,35 @@ import logging import mindspore as ms -from mindspore import nn, ops +from mindspore import mint, nn from .conv import CausalConv3d +from .normalize import Normalize -_logger = logging.getLogger(__name__) - - -class AttnBlock(nn.Cell): - def __init__(self, in_channels, dtype=ms.float32): - super().__init__() - self.in_channels = in_channels - self.dtype = dtype - self.bmm = ops.BatchMatMul() - self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True).to_float( - dtype - ) - self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True).to_float( - dtype - ) - self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True).to_float( - dtype - ) - self.proj_out = nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True - ).to_float(dtype) - self.softmax = nn.Softmax(axis=2) - - def construct(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) +try: + from opensora.npu_config import npu_config, set_run_dtype +except ImportError: + npu_config = None - # compute attention - b, c, h, w = q.shape - q = ops.reshape(q, (b, c, h * w)) - q = ops.transpose(q, (0, 2, 1)) # b,hw,c - k = ops.reshape(k, (b, c, h * w)) # b,c,hw - w_ = self.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - - w_ = w_ * (int(c) ** (-0.5)) - # FIXME: cast w_ to FP32 in amp - w_ = self.softmax(w_) - - # attend to values - v = ops.reshape(v, (b, c, h * w)) - w_ = ops.transpose(w_, (0, 2, 1)) # b,hw,hw (first hw of k, second of q) - h_ = self.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = ops.reshape(h_, (b, c, h, w)) - - h_ = self.proj_out(h_) +_logger = logging.getLogger(__name__) - return x + h_ +class AttnBlock3D(nn.Cell): + """Compatible with old versions, there are issues, use with caution.""" -class AttnBlock3DFix(nn.Cell): - def __init__(self, in_channels, dtype=ms.float32): + def __init__(self, in_channels): super().__init__() self.in_channels = in_channels - self.dtype = dtype - self.bmm = ops.BatchMatMul() - self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - # TODO: 1x1 conv3d can be replaced with flatten and Linear + self.norm = Normalize(in_channels) self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) - self.softmax = nn.Softmax(axis=2) + self.bmm = mint.bmm + self.softmax = mint.nn.Softmax(dim=2) def construct(self, x): - # q shape: (b c t h w) h_ = x h_ = self.norm(h_) q = self.q(h_) @@ -84,103 +37,72 @@ def construct(self, x): v = self.v(h_) # compute attention - # q: (b c t h w) -> (b t c h w) -> (b*t c h*w) -> (b*t h*w c) b, c, t, h, w = q.shape - q = q.permute(0, 2, 1, 3, 4) - q = ops.reshape(q, (b * t, c, h * w)) + q = q.reshape(b * t, c, h * w) q = q.permute(0, 2, 1) # b,hw,c - - # k: (b c t h w) -> (b t c h w) -> (b*t c h*w) - k = k.permute(0, 2, 1, 3, 4) - k = ops.reshape(k, (b * t, c, h * w)) - - # w: (b*t hw hw) - # TODO: support Flash Attention + k = k.reshape(b * t, c, h * w) # b,c,hw w_ = self.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] w_ = w_ * (int(c) ** (-0.5)) - # FIXME: cast w_ to FP32 in amp w_ = self.softmax(w_) # attend to values - # v: (b c t h w) -> (b t c h w) -> (bt c hw) - # w_: (bt hw hw) -> (bt hw hw) - v = v.permute(0, 2, 1, 3, 4) - v = ops.reshape(v, (b * t, c, h * w)) - w_ = ops.transpose(w_, (0, 2, 1)) # b,hw,hw (first hw of k, second of q) + v = v.reshape(b * t, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) h_ = self.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - - # h_: (b*t c hw) -> (b t c h w) -> (b c t h w) - h_ = ops.reshape(h_, (b, t, c, h, w)) - h_ = h_.permute(0, 2, 1, 3, 4) + h_ = h_.reshape(b, c, t, h, w) h_ = self.proj_out(h_) return x + h_ -class AttnBlock3D(nn.Cell): - def __init__(self, in_channels, dtype=ms.float32): +class AttnBlock3DFix(nn.Cell): + """ + Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172. + """ + + def __init__(self, in_channels, norm_type="groupnorm", dtype=ms.float32): super().__init__() self.in_channels = in_channels self.dtype = dtype - self.bmm = ops.BatchMatMul() - self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - # TODO: 1x1 conv3d can be replaced with flatten and Linear + self.norm = Normalize(in_channels, norm_type=norm_type) self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) - self.softmax = nn.Softmax(axis=2) def construct(self, x): - # q shape: (b c t h w) h_ = x h_ = self.norm(h_) q = self.q(h_) k = self.k(h_) v = self.v(h_) - # compute attention - # q: (b c t h w) -> (b t c h w) -> (b*t c h*w) -> (b*t h*w c) b, c, t, h, w = q.shape - q = q.permute(0, 2, 1, 3, 4) - q = ops.reshape(q, (b * t, c, h * w)) - q = q.permute(0, 2, 1) # b,hw,c - - # k: (b c t h w) -> (b t c h w) -> (b*t c h*w) - k = k.permute(0, 2, 1, 3, 4) - k = ops.reshape(k, (b * t, c, h * w)) - - # w: (b*t hw hw) - # TODO: support Flash Attention - w_ = self.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c) ** (-0.5)) - # FIXME: cast w_ to FP32 in amp - w_ = self.softmax(w_) - - # attend to values - # v: (b c t h w) -> (b t c h w) -> (bt c hw) - # w_: (bt hw hw) -> (bt hw hw) - v = v.permute(0, 2, 1, 3, 4) - v = ops.reshape(v, (b * t, c, h * w)) - w_ = ops.transpose(w_, (0, 2, 1)) # b,hw,hw (first hw of k, second of q) - h_ = self.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - - # h_: (b*t c hw) -> (b t c h w) -> (b c t h w) - h_ = ops.reshape(h_, (b, t, c, h, w)) - h_ = h_.permute(0, 2, 1, 3, 4) - - h_ = self.proj_out(h_) + q = q.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c) + k = k.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c) + v = v.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c) + + if npu_config.enable_FA and q.dtype == ms.float32: + dtype = ms.bfloat16 + else: + dtype = None + with set_run_dtype(q, dtype): + query, key, value = npu_config.set_current_run_dtype([q, k, v]) + hidden_states = npu_config.run_attention( + query, + key, + value, + attention_mask=None, + input_layout="BSH", + head_dim=c // 2, + head_num=2, # FIXME: different from torch. To make head_dim 256 instead of 512 + ) + + attn_output = npu_config.restore_dtype(hidden_states) + + attn_output = attn_output.reshape(b, t, h, w, c).permute(0, 4, 1, 2, 3) + h_ = self.proj_out(attn_output) return x + h_ - - -def make_attn(in_channels, attn_type="vanilla", dtype=ms.float32): - assert attn_type in ["vanilla", "vanilla3D"], f"attn_type {attn_type} not supported" - _logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels") - if attn_type == "vanilla": - return AttnBlock(in_channels, dtype=dtype) - elif attn_type == "vanilla3D": - return AttnBlock3D(in_channels, dtype=dtype) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py index 5bf579edd6..bbf1c4b360 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py @@ -1,14 +1,11 @@ -import logging from typing import Tuple, Union -import numpy as np +try: + from opensora.npu_config import npu_config +except ImportError: + npu_config = None -import mindspore as ms -from mindspore import nn, ops - -# from mindspore import mint - -_logger = logging.getLogger(__name__) +from mindspore import mint, nn, ops def divisible_by(num, den): @@ -31,7 +28,7 @@ class Conv2d(nn.Conv2d): def rearrange_in(self, x): # b c f h w -> b f c h w B, C, F, H, W = x.shape - x = ops.transpose(x, (0, 2, 1, 3, 4)) + x = mint.permute(x, (0, 2, 1, 3, 4)) # -> (b*f c h w) x = ops.reshape(x, (-1, C, H, W)) @@ -42,7 +39,7 @@ def rearrange_out(self, x, F): # (b*f D h w) -> (b f D h w) x = ops.reshape(x, (BF // F, F, D, H_, W_)) # -> (b D f h w) - x = ops.transpose(x, (0, 2, 1, 3, 4)) + x = mint.permute(x, (0, 2, 1, 3, 4)) return x @@ -60,101 +57,73 @@ def construct(self, x): class CausalConv3d(nn.Cell): - """ - Temporal padding: Padding with the first frame, by repeating K_t-1 times. - Spatial padding: follow standard conv3d, determined by pad mode and padding - Ref: opensora plan - - Args: - kernel_size: order (T, H, W) - stride: order (T, H, W) - padding: int, controls the amount of spatial padding applied to the input on both sides - """ - def __init__( self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], - padding: int = 0, - dtype=ms.bfloat16, + enable_cached=False, + bias=True, **kwargs, ): super().__init__() - assert isinstance(padding, int) - kernel_size = cast_tuple(kernel_size, 3) - time_kernel_size, height_kernel_size, width_kernel_size = kernel_size - self.time_kernel_size = time_kernel_size - - assert is_odd(height_kernel_size) and is_odd(width_kernel_size) - - dilation = kwargs.pop("dilation", 1) - stride = kwargs.pop("stride", 1) - stride = cast_tuple(stride, 3) # (stride, 1, 1) - dilation = cast_tuple(dilation, 3) # (dilation, 1, 1) - - """ - if isinstance(padding, str): - if padding == 'same': - height_pad = height_kernel_size // 2 - width_pad = width_kernel_size // 2 - elif padding == 'valid': - height_pad = 0 - width_pad = 0 - else: - raise ValueError - else: - padding = list(cast_tuple(padding, 3)) - """ - # pad h,w dimensions if used, by conv3d API - # diff from torch: bias, pad_mode - - # TODO: why not use HeUniform init? - weight_init_value = 1.0 / (np.prod(kernel_size) * chan_in) - if padding == 0: + self.kernel_size = cast_tuple(kernel_size, 3) + self.time_kernel_size = self.kernel_size[0] + self.chan_in = chan_in + self.chan_out = chan_out + self.stride = kwargs.pop("stride", 1) + self.padding = kwargs.pop("padding", 0) + self.stride = cast_tuple(self.stride, 3) + if self.padding == 0: self.conv = nn.Conv3d( - chan_in, - chan_out, - kernel_size, - stride=stride, - dilation=dilation, - has_bias=True, - pad_mode="valid", - weight_init=weight_init_value, - bias_init="zeros", - **kwargs, - ).to_float(dtype) + chan_in, chan_out, self.kernel_size, stride=self.stride, pad_mode="valid", has_bias=bias, **kwargs + ) else: - # axis order (t0, t1, h0 ,h1, w0, w2) - padding = list(cast_tuple(padding, 6)) - padding[0] = 0 - padding[1] = 0 - padding = tuple(padding) + self.padding = list(cast_tuple(self.padding, 6)) + self.padding[0] = 0 + self.padding[1] = 0 + self.conv = nn.Conv3d( chan_in, chan_out, - kernel_size, - stride=stride, - dilation=dilation, - has_bias=True, + self.kernel_size, + stride=self.stride, + padding=tuple(self.padding), pad_mode="pad", - padding=padding, - weight_init=weight_init_value, - bias_init="zeros", + has_bias=bias, **kwargs, - ).to_float(dtype) - self.dtype = dtype + ) + self.enable_cached = enable_cached + self.causal_cached = None + self.cache_offset = 0 def construct(self, x): + x_dtype = x.dtype # x: (bs, Cin, T, H, W ) # first_frame_pad = ops.repeat_interleave(first_frame, (self.time_kernel_size - 1), axis=2) if self.time_kernel_size - 1 > 0: - first_frame = x[:, :, :1, :, :] - first_frame_pad = ops.cat([first_frame] * (self.time_kernel_size - 1), axis=2) - # first_frame_pad = mint.repeat_interleave([first_frame], self.time_kernel_size - 1, 2) - x = ops.concat((first_frame_pad, x), axis=2) + if self.causal_cached is None: + first_frame = x[:, :, :1, :, :] + first_frame_pad = mint.cat([first_frame] * (self.time_kernel_size - 1), dim=2) + # first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) + else: + first_frame_pad = self.causal_cached + + x = mint.cat((first_frame_pad, x), dim=2) + + if self.enable_cached and self.time_kernel_size != 1: + if (self.time_kernel_size - 1) // self.stride[0] != 0: + if self.cache_offset == 0: + self.causal_cached = x[:, :, -(self.time_kernel_size - 1) // self.stride[0] :] + else: + self.causal_cached = x[:, :, : -self.cache_offset][ + :, :, -(self.time_kernel_size - 1) // self.stride[0] : + ] + else: + self.causal_cached = x[:, :, 0:0, :, :] - if x.dtype == ms.float32: - return self.conv(x).to(ms.float32) + if npu_config is not None and npu_config.on_npu: + return npu_config.run_conv3d(self.conv, x, x_dtype) else: - return self.conv(x) + x = self.conv(x) + return x diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/normalize.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/normalize.py index fe3cd5a9c0..cfa961ca4d 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/normalize.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/normalize.py @@ -1,5 +1,6 @@ -import mindspore as ms -from mindspore import Parameter, nn, ops +from mindspore import nn + +from mindone.diffusers.models.normalization import LayerNorm as LayerNorm_diffusers # TODO: put them to modules/normalize.py @@ -13,79 +14,29 @@ def construct(self, x): return y.view(x_shape) -def Normalize(in_channels, num_groups=32, extend=True): - if extend: - return GroupNormExtend(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) - else: - return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) - - -class ActNorm(nn.Cell): - def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False): - assert affine - super().__init__() - self.logdet = logdet - self.loc = Parameter(ops.zeros(1, num_features, 1, 1)) - self.scale = Parameter(ops.ones(1, num_features, 1, 1)) - self.allow_reverse_init = allow_reverse_init - - self.initialized = Parameter(ops.zeros(1, dtype=ms.uint8), requires_grad=False, name="initialized") - - def initialize(self, input): - flatten = input.permute(1, 0, 2, 3).view(input.shape[1], -1) - mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3) - std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3) - - self.loc.set_data(-mean) - self.scale.set_data(1 / (std + 1e-6)) - - def construct(self, input, reverse=False): - if reverse: - return self.reverse(input) - if len(input.shape) == 2: - input = input[:, :, None, None] - squeeze = True - else: - squeeze = False - - _, _, height, width = input.shape - - if self.training and self.initialized[0] == 0: - self.initialize(input) # stop_grads? - self.initialized.set_data(ops.ones(1, dtype=ms.uint8)) - - h = self.scale * (input + self.loc) - - if squeeze: - h = h.squeeze(-1).squeeze(-1) - - if self.logdet: - log_abs = ops.log(ops.abs(self.scale)) - logdet = height * width * ops.sum(log_abs) - logdet = logdet * ops.ones(input.shape[0]).to(input.dtype) - return h, logdet +class LayerNorm(nn.Cell): + def __init__(self, num_channels, eps=1e-6, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.norm = LayerNorm_diffusers(num_channels, eps=eps, elementwise_affine=True) - return h - - def reverse(self, output): - if self.training and self.initialized.item() == 0: - if not self.allow_reverse_init: - raise RuntimeError( - "Initializing ActNorm in reverse direction is " - "disabled by default. Use allow_reverse_init=True to enable." - ) - else: - self.initialize(output) - self.initialized.fill_(1) - - if len(output.shape) == 2: - output = output[:, :, None, None] - squeeze = True + def construct(self, x): + if x.ndim == 5: + # b c t h w -> b t h w c + x = x.transpose(0, 2, 3, 4, 1) + x = self.norm(x) + # b t h w c -> b c t h w + x = x.transpose(0, 4, 1, 2, 3) else: - squeeze = False + # b c h w -> b h w c + x = x.transpose(0, 2, 3, 1) + x = self.norm(x) + # b h w c -> b c h w + x = x.transpose(0, 3, 1, 2) + return x - h = output / self.scale - self.loc - if squeeze: - h = h.squeeze(-1).squeeze(-1) - return h +def Normalize(in_channels, num_groups=32, norm_type="groupnorm"): + if norm_type == "groupnorm": + return GroupNormExtend(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == "layernorm": + return LayerNorm(num_channels=in_channels, eps=1e-6) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/ops.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/ops.py index 1e762d934e..0ec15ccec0 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/ops.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/ops.py @@ -1,15 +1,38 @@ import mindspore as ms -from mindspore import nn +from mindspore import mint + + +def video_to_image(func): + def wrapper(self, x, *args, **kwargs): + if x.ndim == 5: + b, c, t, h, w = x.shape + if True: + # b c t h w -> (b t) c h w + x = x.swapaxes(1, 2).reshape(-1, c, h, w) # (b*t, c, h, w) + x = func(self, x, *args, **kwargs) + x = x.reshape(x.shape[0] // t, t, x.shape[1], x.shape[2], x.shape[3]) # (b, t, c, h, w) + x = x.transpose(0, 2, 1, 3, 4) # (b, c, t, h, w) + else: + # Conv 2d slice infer + result = [] + for i in range(t): + frame = x[:, :, i, :, :] + frame = func(self, frame, *args, **kwargs) + result.append(frame.unsqueeze(2)) + x = mint.cat(result, dim=2) + return x + + return wrapper def nonlinearity(x, upcast=False): # swish ori_dtype = x.dtype - # if upcast: - # return x * (ops.sigmoid(x.astype(ms.float32))).astype(ori_dtype) - # else: - # return x * (ops.sigmoid(x)) - return nn.SiLU()(x.astype(ms.float32) if upcast else x).to(ori_dtype) + if upcast: + return x * (mint.sigmoid(x.astype(ms.float32))).astype(ori_dtype) + else: + return x * (mint.sigmoid(x)) + # return nn.SiLU()(x.astype(ms.float32) if upcast else x).to(ori_dtype) def divisible_by(num, den): diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/resnet_block.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/resnet_block.py index 729d233b44..eb6a4c697d 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/resnet_block.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/resnet_block.py @@ -1,12 +1,16 @@ import mindspore as ms -from mindspore import nn, ops +from mindspore import nn +try: + from opensora.npu_config import npu_config +except ImportError: + npu_config = None from .conv import CausalConv3d -from .ops import nonlinearity +from .normalize import Normalize +from .ops import nonlinearity, video_to_image -# used in vae -class ResnetBlock(nn.Cell): +class ResnetBlock3D(nn.Cell): def __init__( self, *, @@ -14,76 +18,21 @@ def __init__( out_channels=None, conv_shortcut=False, dropout, - temb_channels=512, + norm_type, dtype=ms.float32, upcast_sigmoid=False, ): super().__init__() self.dtype = dtype self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - self.upcast_sigmoid = upcast_sigmoid - - self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - self.conv1 = nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True - ).to_float(dtype) - if temb_channels > 0: - self.temb_proj = nn.Dense(temb_channels, out_channels, bias_init="normal").to_float(dtype) - self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) - self.dropout = nn.Dropout(p=dropout) - self.conv2 = nn.Conv2d( - out_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True - ).to_float(dtype) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True - ).to_float(dtype) - else: - self.nin_shortcut = nn.Conv2d( - in_channels, out_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True - ).to_float(dtype) - - def construct(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h, upcast=self.upcast_sigmoid) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb, upcast=self.upcast_sigmoid))[:, :, None, None] - - h = self.norm2(h) - h = nonlinearity(h, upcast=self.upcast_sigmoid) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x + h - - -class ResnetBlock3D(nn.Cell): - def __init__( - self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, dtype=ms.float32, upcast_sigmoid=False - ): - super().__init__() - self.in_channels = in_channels self.out_channels = in_channels if out_channels is None else out_channels self.use_conv_shortcut = conv_shortcut self.upcast_sigmoid = upcast_sigmoid # FIXME: GroupNorm precision mismatch with PT. - self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.norm1 = Normalize(in_channels, norm_type=norm_type) self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1) - self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.norm2 = Normalize(out_channels, norm_type=norm_type) self.dropout = nn.Dropout(p=dropout) self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1) if self.in_channels != self.out_channels: @@ -94,10 +43,10 @@ def __init__( def construct(self, x): h = x - h = self.norm1(h) + h = npu_config.run_group_norm(self.norm1, h) h = nonlinearity(h, self.upcast_sigmoid) h = self.conv1(h) - h = self.norm2(h) + h = npu_config.run_group_norm(self.norm2, h) h = nonlinearity(h, self.upcast_sigmoid) h = self.dropout(h) h = self.conv2(h) @@ -117,6 +66,7 @@ def __init__( in_channels, out_channels=None, conv_shortcut=False, + norm_type, dropout, dtype=ms.float32, upcast_sigmoid=False, @@ -129,11 +79,11 @@ def __init__( self.use_conv_shortcut = conv_shortcut self.upcast_sigmoid = upcast_sigmoid - self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.norm1 = Normalize(in_channels, norm_type=norm_type) self.conv1 = nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True ).to_float(dtype) - self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.norm2 = Normalize(out_channels, norm_type=norm_type) self.dropout = nn.Dropout(p=dropout) self.conv2 = nn.Conv2d( out_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True @@ -148,48 +98,22 @@ def __init__( in_channels, out_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True ).to_float(dtype) - def rearrange_in(self, x): - # b c f h w -> b f c h w - B, C, F, H, W = x.shape - x = ops.transpose(x, (0, 2, 1, 3, 4)) - # -> (b*f c h w) - x = ops.reshape(x, (-1, C, H, W)) - - return x - - def rearrange_out(self, x, F): - BF, D, H_, W_ = x.shape - # (b*f D h w) -> (b f D h w) - x = ops.reshape(x, (BF // F, F, D, H_, W_)) - # -> (b D f h w) - x = ops.transpose(x, (0, 2, 1, 3, 4)) - - return x - + @video_to_image def construct(self, x): - # import pdb; pdb.set_trace() - # x: (b c f h w) - # rearrange in - F = x.shape[-3] - x = self.rearrange_in(x) - h = x - h = self.norm1(h) - h = nonlinearity(h, upcast=self.upcast_sigmoid) + + h = npu_config.run_group_norm(self.norm1, h) + h = nonlinearity(h, self.upcast_sigmoid) h = self.conv1(h) - h = self.norm2(h) - h = nonlinearity(h, upcast=self.upcast_sigmoid) + h = npu_config.run_group_norm(self.norm2, h) + h = nonlinearity(h, self.upcast_sigmoid) h = self.dropout(h) h = self.conv2(h) - if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) else: x = self.nin_shortcut(x) - x = x + h - # rearrange out - x = self.rearrange_out(x, F) return x diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py index c1887d11d6..0f7f12043d 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py @@ -1,76 +1,64 @@ from typing import Tuple, Union +from opensora.npu_config import npu_config + import mindspore as ms -from mindspore import nn, ops +from mindspore import mint, nn, ops from .conv import CausalConv3d -from .ops import cast_tuple +from .ops import cast_tuple, video_to_image class Upsample(nn.Cell): - def __init__(self, in_channels, with_conv, dtype=ms.float32): + def __init__(self, in_channels, out_channels, with_conv=True, dtype=ms.float32): super().__init__() self.dtype = dtype self.with_conv = with_conv if self.with_conv: self.conv = nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True + in_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True ).to_float(self.dtype) + @video_to_image def construct(self, x): in_shape = x.shape[-2:] out_shape = tuple(2 * x for x in in_shape) x = ops.ResizeNearestNeighbor(out_shape)(x) - if self.with_conv: x = self.conv(x) return x class Downsample(nn.Cell): - def __init__(self, in_channels, with_conv=True, dtype=ms.float32): + def __init__(self, in_channels, out_channels, undown=False, dtype=ms.float32): super().__init__() + self.with_conv = True + self.undown = undown self.dtype = dtype - self.with_conv = with_conv - assert with_conv, "Downsample is forced to use conv in opensora v1.1" if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves - self.conv = nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=2, pad_mode="valid", padding=0, has_bias=True - ).to_float(self.dtype) - - def rearrange_in(self, x): - # b c f h w -> b f c h w - B, C, F, H, W = x.shape - x = ops.transpose(x, (0, 2, 1, 3, 4)) - # -> (b*f c h w) - x = ops.reshape(x, (-1, C, H, W)) - - return x - - def rearrange_out(self, x, F): - BF, D, H_, W_ = x.shape - # (b*f D h w) -> (b f D h w) - x = ops.reshape(x, (BF // F, F, D, H_, W_)) - # -> (b D f h w) - x = ops.transpose(x, (0, 2, 1, 3, 4)) - - return x + if self.undown: + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1, pad_mode="pad", has_bias=True + ).to_float(self.dtype) + else: + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, padding=0, pad_mode="pad", has_bias=True + ).to_float(self.dtype) + @video_to_image def construct(self, x): - F = x.shape[-3] - x = self.rearrange_in(x) - if self.with_conv: - pad = ((0, 0), (0, 0), (0, 1), (0, 1)) - x = nn.Pad(paddings=pad)(x) - # pad = (0, 1, 0, 1) # (pad_left, pad_right, pad_top, pad_bottom) - # x = ops.pad(x, pad, mode="constant", value=0) - x = self.conv(x) + if self.undown: + x = self.conv(x) + else: + pad = (0, 1, 0, 1) + x = mint.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) else: - x = ops.AvgPool(kernel_size=2, stride=2)(x) - - x = self.rearrange_out(x, F) + x = npu_config.run_pool_2d( + mint.nn.functional.avg_pool2d, kernel_size=2, stride=2 + ) # avgpool does not support bf16, but only fp32 and fp16 return x @@ -105,8 +93,8 @@ def __init__( def construct(self, x): # x shape: (b c t h w) - # x = ops.pad(x, self.padding, mode="constant", value=0) - x = self.pad(x) + pad = (0, 1, 0, 1, 0, 0) + x = mint.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) return x @@ -170,8 +158,8 @@ def __init__( def construct(self, x): first_frame = x[:, :, :1, :, :] # first_frame_pad = ops.repeat_interleave(first_frame, self.time_pad, axis=2) - first_frame_pad = ops.cat([first_frame] * self.time_pad, axis=2) - x = ops.concat((first_frame_pad, x), axis=2) + first_frame_pad = mint.cat([first_frame] * self.time_pad, dim=2) + x = mint.cat((first_frame_pad, x), dim=2) if not self.replace_avgpool3d: return self.conv(x) @@ -195,7 +183,7 @@ def construct(self, x): x, x_ = x[:, :, :1], x[:, :, 1:] # FIXME: ms2.2.10 cannot support trilinear on 910b x_ = ops.interpolate(x_, scale_factor=(2.0, 1.0, 1.0), mode="trilinear") - x = ops.concat([x, x_], axis=2) + x = mint.cat([x, x_], dim=2) else: x = ops.interpolate(x, scale_factor=(2.0, 1.0, 1.0), mode="trilinear") @@ -233,14 +221,17 @@ def __init__( self.mix_factor = ms.Parameter(ms.Tensor([mix_factor]), requires_grad=True) def construct(self, x): - alpha = ops.sigmoid(self.mix_factor) + alpha = mint.sigmoid(self.mix_factor) first_frame = x[:, :, :1, :, :] # first_frame_pad = ops.repeat_interleave(first_frame, self.time_pad, axis=2) - first_frame_pad = ops.cat([first_frame] * self.time_pad, axis=2) - x = ops.concat((first_frame_pad, x), axis=2) - - conv_out = self.conv(x) + first_frame_pad = mint.cat([first_frame] * self.time_pad, dim=2) + x = mint.cat((first_frame_pad, x), dim=2) + if npu_config is not None and npu_config.on_npu: + x_dtype = x.dtype + conv_out = npu_config.run_conv3d(self.conv, x, x_dtype) + else: + conv_out = self.conv(x) # avg pool if not self.replace_avgpool3d: @@ -266,16 +257,16 @@ def __init__( super().__init__() self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1) self.mix_factor = ms.Parameter(ms.Tensor([mix_factor]), requires_grad=True) - self.intepolate = TrilinearInterpolate() + self.interpolate = TrilinearInterpolate() def construct(self, x): - alpha = ops.sigmoid(self.mix_factor) + alpha = mint.sigmoid(self.mix_factor) if x.shape[2] > 1: x, x_ = x[:, :, :1], x[:, :, 1:] ori_dtype = x.dtype # FIXME: ms2.2.10 cannot support trilinear on 910b - x_ = self.intepolate(x_, scale_factor=(2.0, 1.0, 1.0)).to(ori_dtype) - x = ops.concat([x, x_], axis=2) + x_ = self.interpolate(x_, scale_factor=(2.0, 1.0, 1.0)).to(ori_dtype) + x = mint.cat([x, x_], dim=2) return alpha * x + (1 - alpha) * self.conv(x) @@ -286,20 +277,43 @@ def construct(self, x, scale_factor): class Spatial2xTime2x3DUpsample(nn.Cell): - def __init__(self, in_channels, out_channels, dtype=ms.float32): + def __init__( + self, + in_channels, + out_channels, + dtype=ms.float32, + t_interpolation="trilinear", + enable_cached=False, + ): super().__init__() self.dtype = dtype + self.t_interpolation = t_interpolation + assert self.t_interpolation == "trilinear", "only support trilinear interpolation now" self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=1) - self.intepolate = TrilinearInterpolate() + self.interpolate = TrilinearInterpolate() + self.enable_cached = enable_cached + self.causal_cached = None def construct(self, x): - if x.shape[2] > 1: - x, x_ = x[:, :, :1], x[:, :, 1:] - x_ = self.intepolate(x_, scale_factor=(2.0, 2.0, 2.0)) - x = self.intepolate(x, scale_factor=(1.0, 2.0, 2.0)) - x = ops.cat([x, x_], axis=2) + if x.shape[2] > 1 or self.causal_cached is not None: + if self.enable_cached and self.causal_cached is not None: + x = mint.cat([self.causal_cached, x], dim=2) + self.causal_cached = x[:, :, -2:-1] + x = npu_config.run_interpolate(self.interpolate, x, scale_factor=(2.0, 1.0, 1.0)) + x = x[:, :, 2:] + x = npu_config.run_interpolate(self.interpolate, x, scale_factor=(1.0, 2.0, 2.0)) + else: + if self.enable_cached: + self.causal_cached = x[:, :, -1:] + x, x_ = x[:, :, :1], x[:, :, 1:] + x_ = npu_config.run_interpolate(self.interpolate, x_, scale_factor=(2.0, 1.0, 1.0)) + x_ = npu_config.run_interpolate(self.interpolate, x_, scale_factor=(1.0, 2.0, 2.0)) + x = npu_config.run_interpolate(self.interpolate, x, scale_factor=(1.0, 2.0, 2.0)) + x = mint.cat([x, x_], dim=2) else: - x = self.intepolate(x, scale_factor=(1.0, 2.0, 2.0)) + if self.enable_cached: + self.causal_cached = x[:, :, -1:] + x = npu_config.run_interpolate(self.interpolate, x, scale_factor=(1.0, 2.0, 2.0)) return self.conv(x) @@ -308,9 +322,9 @@ def __init__(self, in_channels, out_channels, dtype=ms.float32): super().__init__() self.dtype = dtype self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=0, stride=2) - self.pad = ops.Pad(paddings=((0, 0), (0, 0), (0, 0), (0, 1), (0, 1))) def construct(self, x): - x = self.pad(x) + pad = (0, 1, 0, 1, 0, 0) + x = mint.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) return x diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py new file mode 100644 index 0000000000..f7c52258e4 --- /dev/null +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py @@ -0,0 +1,259 @@ +import mindspore as ms +from mindspore import Tensor, mint, nn, ops + +from ..modules import CausalConv3d +from ..modules.ops import video_to_image + +try: + from opensora.npu_config import npu_config +except ImportError: + npu_config = None + + +class HaarWaveletTransform3D(nn.Cell): + def __init__(self, dtype=ms.float32) -> None: + super().__init__() + self.dtype = dtype + h = Tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]) * 0.3536 + g = Tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]]) * 0.3536 + hh = Tensor([[[1, 1], [-1, -1]], [[1, 1], [-1, -1]]]) * 0.3536 + gh = Tensor([[[1, -1], [-1, 1]], [[1, -1], [-1, 1]]]) * 0.3536 + h_v = Tensor([[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]]) * 0.3536 + g_v = Tensor([[[1, -1], [1, -1]], [[-1, 1], [-1, 1]]]) * 0.3536 + hh_v = Tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]]) * 0.3536 + gh_v = Tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]]) * 0.3536 + h = h.view(1, 1, 2, 2, 2) + g = g.view(1, 1, 2, 2, 2) + hh = hh.view(1, 1, 2, 2, 2) + gh = gh.view(1, 1, 2, 2, 2) + h_v = h_v.view(1, 1, 2, 2, 2) + g_v = g_v.view(1, 1, 2, 2, 2) + hh_v = hh_v.view(1, 1, 2, 2, 2) + gh_v = gh_v.view(1, 1, 2, 2, 2) + + self.h_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) + self.g_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) + self.hh_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) + self.gh_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) + self.h_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) + self.g_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) + self.hh_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) + self.gh_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) + + self.h_conv.conv.weight.set_data(h) + self.g_conv.conv.weight.set_data(g) + self.hh_conv.conv.weight.set_data(hh) + self.gh_conv.conv.weight.set_data(gh) + self.h_v_conv.conv.weight.set_data(h_v) + self.g_v_conv.conv.weight.set_data(g_v) + self.hh_v_conv.conv.weight.set_data(hh_v) + self.gh_v_conv.conv.weight.set_data(gh_v) + self.h_conv.requires_grad = False + self.g_conv.requires_grad = False + self.hh_conv.requires_grad = False + self.gh_conv.requires_grad = False + self.h_v_conv.requires_grad = False + self.g_v_conv.requires_grad = False + self.hh_v_conv.requires_grad = False + self.gh_v_conv.requires_grad = False + + def construct(self, x): + assert x.ndim == 5 + + b = x.shape[0] + # b c t h w -> (b c) 1 t h w + x = x.reshape(-1, 1, *x.shape[-3:]) + low_low_low = self.h_conv(x) + low_low_low = low_low_low.reshape( + b, low_low_low.shape[0] // b, *low_low_low.shape[-3:] + ) # (b c) 1 t h w -> b c t h w + low_low_high = self.g_conv(x) + low_low_high = low_low_high.reshape( + b, low_low_high.shape[0] // b, *low_low_high.shape[-3:] + ) # (b c) 1 t h w -> b c t h w + low_high_low = self.hh_conv(x) + low_high_low = low_high_low.reshape( + b, low_high_low.shape[0] // b, *low_high_low.shape[-3:] + ) # (b c) 1 t h w -> b c t h w + low_high_high = self.gh_conv(x) + low_high_high = low_high_high.reshape( + b, low_high_high.shape[0] // b, *low_high_high.shape[-3:] + ) # (b c) 1 t h w -> b c t h w + high_low_low = self.h_v_conv(x) + high_low_low = high_low_low.reshape( + b, high_low_low.shape[0] // b, *high_low_low.shape[-3:] + ) # (b c) 1 t h w -> b c t h w + high_low_high = self.g_v_conv(x) + high_low_high = high_low_high.reshape( + b, high_low_high.shape[0] // b, *high_low_high.shape[-3:] + ) # (b c) 1 t h w -> b c t h w + high_high_low = self.hh_v_conv(x) + high_high_low = high_high_low.reshape( + b, high_high_low.shape[0] // b, *high_high_low.shape[-3:] + ) # (b c) 1 t h w -> b c t h w + high_high_high = self.gh_v_conv(x) + high_high_high = high_high_high.reshape( + b, high_high_high.shape[0] // b, *high_high_high.shape[-3:] + ) # (b c) 1 t h w -> b c t h w + + output = mint.cat( + [ + low_low_low, + low_low_high, + low_high_low, + low_high_high, + high_low_low, + high_low_high, + high_high_low, + high_high_high, + ], + dim=1, + ) + + return output + + +class InverseHaarWaveletTransform3D(nn.Cell): + def __init__(self, enable_cached=False, dtype=ms.float16, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.dtype = dtype + + if self.dtype == ms.float32 or self.dtype == ms.bfloat16: + self.dtype = ms.float16 + dtype = ms.float16 + print("conv3d transpose layer is forced to fp16") + + self.h = Tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]], dtype=dtype).view(1, 1, 2, 2, 2) * 0.3536 + self.g = Tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]], dtype=dtype).view(1, 1, 2, 2, 2) * 0.3536 + self.hh = Tensor([[[1, 1], [-1, -1]], [[1, 1], [-1, -1]]], dtype=dtype).view(1, 1, 2, 2, 2) * 0.3536 + self.gh = Tensor([[[1, -1], [-1, 1]], [[1, -1], [-1, 1]]], dtype=dtype).view(1, 1, 2, 2, 2) * 0.3536 + self.h_v = Tensor([[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]], dtype=dtype).view(1, 1, 2, 2, 2) * 0.3536 + self.g_v = Tensor([[[1, -1], [1, -1]], [[-1, 1], [-1, 1]]], dtype=dtype).view(1, 1, 2, 2, 2) * 0.3536 + self.hh_v = Tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]], dtype=dtype).view(1, 1, 2, 2, 2) * 0.3536 + self.gh_v = Tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]], dtype=dtype).view(1, 1, 2, 2, 2) * 0.3536 + self.enable_cached = enable_cached + self.causal_cached = None + self.conv_transpose3d = ops.Conv3DTranspose(1, 1, kernel_size=2, stride=2) + + def construct(self, coeffs): + assert coeffs.ndim == 5 + input_dtype = coeffs.dtype + coeffs = coeffs.to(self.dtype) + b = coeffs.shape[0] + + ( + low_low_low, + low_low_high, + low_high_low, + low_high_high, + high_low_low, + high_low_high, + high_high_low, + high_high_high, + ) = mint.chunk(coeffs, 8, dim=1) + + low_low_low = low_low_low.reshape(-1, 1, *low_low_low.shape[-3:]) + low_low_high = low_low_high.reshape(-1, 1, *low_low_high.shape[-3:]) + low_high_low = low_high_low.reshape(-1, 1, *low_high_low.shape[-3:]) + low_high_high = low_high_high.reshape(-1, 1, *low_high_high.shape[-3:]) + high_low_low = high_low_low.reshape(-1, 1, *high_low_low.shape[-3:]) + high_low_high = high_low_high.reshape(-1, 1, *high_low_high.shape[-3:]) + high_high_low = high_high_low.reshape(-1, 1, *high_high_low.shape[-3:]) + high_high_high = high_high_high.reshape(-1, 1, *high_high_high.shape[-3:]) + + low_low_low = self.conv_transpose3d(low_low_low, self.h) + low_low_high = self.conv_transpose3d(low_low_high, self.g) + low_high_low = self.conv_transpose3d(low_high_low, self.hh) + low_high_high = self.conv_transpose3d(low_high_high, self.gh) + high_low_low = self.conv_transpose3d(high_low_low, self.h_v) + high_low_high = self.conv_transpose3d(high_low_high, self.g_v) + high_high_low = self.conv_transpose3d(high_high_low, self.hh_v) + high_high_high = self.conv_transpose3d(high_high_high, self.gh_v) + if self.enable_cached and self.causal_cached: + reconstructed = ( + low_low_low + + low_low_high + + low_high_low + + low_high_high + + high_low_low + + high_low_high + + high_high_low + + high_high_high + ) + else: + reconstructed = ( + low_low_low[:, :, 1:] + + low_low_high[:, :, 1:] + + low_high_low[:, :, 1:] + + low_high_high[:, :, 1:] + + high_low_low[:, :, 1:] + + high_low_high[:, :, 1:] + + high_high_low[:, :, 1:] + + high_high_high[:, :, 1:] + ) + self.causal_cached = True + reconstructed = reconstructed.reshape(b, -1, *reconstructed.shape[-3:]) + + return reconstructed.to(input_dtype) + + +class HaarWaveletTransform2D(nn.Cell): + def __init__(self, dtype=ms.float32): + super().__init__() + self.dtype = dtype + self.aa = Tensor([[1, 1], [1, 1]], dtype=dtype).view(1, 1, 2, 2) / 2 + self.ad = Tensor([[1, 1], [-1, -1]], dtype=dtype).view(1, 1, 2, 2) / 2 + self.da = Tensor([[1, -1], [1, -1]], dtype=dtype).view(1, 1, 2, 2) / 2 + self.dd = Tensor([[1, -1], [-1, 1]], dtype=dtype).view(1, 1, 2, 2) / 2 + + @video_to_image + def construct(self, x): + b, c, h, w = x.shape + x = x.reshape(b * c, 1, h, w) + x_dtype = x.dtype + if x.dtype != self.dtype: + x = x.to(self.dtype) + + low_low = ops.conv2d(x, self.aa, stride=2).reshape(b, c, h // 2, w // 2) + low_high = ops.conv2d(x, self.ad, stride=2).reshape(b, c, h // 2, w // 2) + high_low = ops.conv2d(x, self.da, stride=2).reshape(b, c, h // 2, w // 2) + high_high = ops.conv2d(x, self.dd, stride=2).reshape(b, c, h // 2, w // 2) + coeffs = mint.cat([low_low, low_high, high_low, high_high], dim=1) + return coeffs.to(x_dtype) + + +class InverseHaarWaveletTransform2D(nn.Cell): + def __init__(self, dtype=ms.float32): + super().__init__() + aa = Tensor([[1, 1], [1, 1]]).view(1, 1, 2, 2) / 2 + ad = Tensor([[1, 1], [-1, -1]]).view(1, 1, 2, 2) / 2 + da = Tensor([[1, -1], [1, -1]]).view(1, 1, 2, 2) / 2 + dd = Tensor([[1, -1], [-1, 1]]).view(1, 1, 2, 2) / 2 + self.dtype = dtype + self.aa = nn.Conv2dTranspose(1, 1, kernel_size=2, stride=2, has_bias=False).to_float(dtype) + self.ad = nn.Conv2dTranspose(1, 1, kernel_size=2, stride=2, has_bias=False).to_float(dtype) + self.da = nn.Conv2dTranspose(1, 1, kernel_size=2, stride=2, has_bias=False).to_float(dtype) + self.dd = nn.Conv2dTranspose(1, 1, kernel_size=2, stride=2, has_bias=False).to_float(dtype) + self.aa.weight.set_data(aa) + self.aa.requires_grad = False + self.ad.weight.set_data(ad) + self.ad.requires_grad = False + self.da.weight.set_data(da) + self.da.requires_grad = False + self.dd.weight.set_data(dd) + self.dd.requires_grad = False + + @video_to_image + def construct(self, coeffs): + low_low, low_high, high_low, high_high = mint.chunk(coeffs, 4, dim=1) + b, c, height_half, width_half = low_low.shape + height = height_half * 2 + width = width_half * 2 + + low_low = self.aa(low_low.reshape(b * c, 1, height_half, width_half)) + low_high = self.ad(low_high.reshape(b * c, 1, height_half, width_half)) + high_low = self.da(high_low.reshape(b * c, 1, height_half, width_half)) + high_high = self.dd(high_high.reshape(b * c, 1, height_half, width_half)) + + return (low_low + low_high + high_low + high_high).reshape(b, c, height, width) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/registry.py b/examples/opensora_pku/opensora/models/causalvideovae/model/registry.py new file mode 100644 index 0000000000..9c6e86e423 --- /dev/null +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/registry.py @@ -0,0 +1,14 @@ +class ModelRegistry: + _models = {} + + @classmethod + def register(cls, model_name): + def decorator(model_class): + cls._models[model_name] = model_class + return model_class + + return decorator + + @classmethod + def get_model(cls, model_name): + return cls._models.get(model_name) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/utils/distrib_utils.py b/examples/opensora_pku/opensora/models/causalvideovae/model/utils/distrib_utils.py new file mode 100644 index 0000000000..ed5178e8a6 --- /dev/null +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/utils/distrib_utils.py @@ -0,0 +1,43 @@ +import mindspore as ms +from mindspore import Tensor, mint + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.mean, self.logvar = mint.split(parameters, [parameters.shape[1] // 2, parameters.shape[1] // 2], dim=1) + self.logvar = mint.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = mint.exp(0.5 * self.logvar) + self.var = mint.exp(self.logvar) + self.stdnormal = mint.normal + if self.deterministic: + self.var = self.std = mint.zeros_like(self.mean, dtype=self.mean.dtype) + + def sample(self): + x = self.mean + self.std * self.stdnormal(size=self.mean.shape) + return x + + def kl(self, other=None): + if self.deterministic: + return Tensor([0.0]) + else: + if other is None: + return 0.5 * mint.sum(mint.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) + else: + return 0.5 * mint.sum( + mint.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return Tensor([0.0]) + logtwopi = ms.numpy.log(2.0 * ms.numpy.pi) + return 0.5 * mint.sum(logtwopi + self.logvar + mint.pow(sample - self.mean, 2) / self.var, dim=dims) + + def mode(self): + return self.mean diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/utils/video_utils.py b/examples/opensora_pku/opensora/models/causalvideovae/model/utils/video_utils.py index 8f2b71cf2f..3a84670188 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/utils/video_utils.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/utils/video_utils.py @@ -1,10 +1,10 @@ import numpy as np -from mindspore import ops +from mindspore import mint def tensor_to_video(x): - x = ops.clamp(x, -1, 1) + x = mint.clamp(x, -1, 1) x = (x + 1) / 2 x = x.permute(1, 0, 2, 3).float().asnumpy() # c t h w -> x = (255 * x).astype(np.uint8) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/__init__.py b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/__init__.py new file mode 100644 index 0000000000..92b8b07266 --- /dev/null +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/__init__.py @@ -0,0 +1,6 @@ +import logging + +from .modeling_causalvae import CausalVAEModel +from .modeling_wfvae import WFVAEModel + +logger = logging.getLogger(__name__) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_causalvae.py b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_causalvae.py new file mode 100644 index 0000000000..bc38c0f975 --- /dev/null +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_causalvae.py @@ -0,0 +1,946 @@ +import logging +import os +from typing import Tuple + +from opensora.acceleration.parallel_states import get_sequence_parallel_state + +import mindspore as ms +from mindspore import mint, nn, ops + +from mindone.diffusers import __version__ +from mindone.diffusers.models.modeling_utils import load_state_dict +from mindone.diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, _add_variant, _get_model_file + +from ..modeling_videobase import VideoBaseAE +from ..modules.conv import CausalConv3d +from ..modules.ops import nonlinearity +from ..utils.model_utils import resolve_str_to_obj + +logger = logging.getLogger(__name__) + + +class CausalVAEModel(VideoBaseAE): + """ + The default vales are set to be the same as those used in OpenSora v1.1 + """ + + def __init__( + self, + lr: float = 1e-5, # ignore + hidden_size: int = 128, + z_channels: int = 4, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = [], + dropout: float = 0.0, + resolution: int = 256, + double_z: bool = True, + embed_dim: int = 4, + num_res_blocks: int = 2, + q_conv: str = "CausalConv3d", + encoder_conv_in: str = "Conv2d", + encoder_conv_out: str = "CausalConv3d", + encoder_attention: str = "AttnBlock3DFix", + encoder_resnet_blocks: Tuple[str] = ( + "ResnetBlock2D", + "ResnetBlock2D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + encoder_spatial_downsample: Tuple[str] = ( + "Downsample", + "Downsample", + "Downsample", + "", + ), + encoder_temporal_downsample: Tuple[str] = ( + "", + "TimeDownsampleRes2x", + "TimeDownsampleRes2x", + "", + ), + encoder_mid_resnet: str = "ResnetBlock3D", + decoder_conv_in: str = "CausalConv3d", + decoder_conv_out: str = "CausalConv3d", + decoder_attention: str = "AttnBlock3DFix", + decoder_resnet_blocks: Tuple[str] = ( + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + decoder_spatial_upsample: Tuple[str] = ( + "", + "SpatialUpsample2x", + "SpatialUpsample2x", + "SpatialUpsample2x", + ), + decoder_temporal_upsample: Tuple[str] = ("", "", "TimeUpsampleRes2x", "TimeUpsampleRes2x"), + decoder_mid_resnet: str = "ResnetBlock3D", + use_quant_layer: bool = True, + ckpt_path=None, + ignore_keys=[], + monitor=None, + use_fp16=False, + upcast_sigmoid=False, + ): + super().__init__() + dtype = ms.float16 if use_fp16 else ms.float32 + + self.encoder = Encoder( + z_channels=z_channels, + hidden_size=hidden_size, + hidden_size_mult=hidden_size_mult, + attn_resolutions=attn_resolutions, + conv_in=encoder_conv_in, + conv_out=encoder_conv_out, + attention=encoder_attention, + resnet_blocks=encoder_resnet_blocks, + spatial_downsample=encoder_spatial_downsample, + temporal_downsample=encoder_temporal_downsample, + mid_resnet=encoder_mid_resnet, + dropout=dropout, + resolution=resolution, + num_res_blocks=num_res_blocks, + double_z=double_z, + dtype=dtype, + upcast_sigmoid=upcast_sigmoid, + ) + + self.decoder = Decoder( + z_channels=z_channels, + hidden_size=hidden_size, + hidden_size_mult=hidden_size_mult, + attn_resolutions=attn_resolutions, + conv_in=decoder_conv_in, + conv_out=decoder_conv_out, + attention=decoder_attention, + resnet_blocks=decoder_resnet_blocks, + spatial_upsample=decoder_spatial_upsample, + temporal_upsample=decoder_temporal_upsample, + mid_resnet=decoder_mid_resnet, + dropout=dropout, + resolution=resolution, + num_res_blocks=num_res_blocks, + dtype=dtype, + upcast_sigmoid=upcast_sigmoid, + ) + self.embed_dim = embed_dim + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + self.exp = mint.exp + self.stdnormal = mint.normal + self.depend = ops.Depend() if get_sequence_parallel_state() else None + + # self.encoder.recompute() + # self.decoder.recompute() + self.tile_sample_min_size = 256 + self.tile_sample_min_size_t = 33 + self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1))) + # t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0] + # self.tile_latent_min_size_t = int((self.tile_sample_min_size_t - 1) / (2 ** len(t_down_ratio))) + 1 + self.tile_latent_min_size_t = 16 + self.tile_overlap_factor = 0.125 + self.use_tiling = False + self.use_quant_layer = use_quant_layer + if self.use_quant_layer: + quant_conv_cls = resolve_str_to_obj(q_conv) + self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1) + self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1) + + def get_encoder(self): + if self.use_quant_layer: + return [self.quant_conv, self.encoder] + return [self.encoder] + + def get_decoder(self): + if self.use_quant_layer: + return [self.post_quant_conv, self.decoder] + return [self.decoder] + + # rewrite class method to allow the state dict as input + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + state_dict = kwargs.pop("state_dict", None) # additional key argument + cache_dir = kwargs.pop("cache_dir", None) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + mindspore_dtype = kwargs.pop("mindspore_dtype", None) + subfolder = kwargs.pop("subfolder", None) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + **kwargs, + ) + + # load model + model_file = None + if from_flax: + raise NotImplementedError("loading flax checkpoint in mindspore model is not yet supported.") + else: + if state_dict is None: # edits: only search for model_file if state_dict is not provided + if use_safetensors: + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + except IOError as e: + if not allow_pickle: + raise e + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + model = cls.from_config(config, **unused_kwargs) + if state_dict is None: # edits: only load model_file if state_dict is None + state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) + + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type): + raise ValueError( + f"{mindspore_dtype} needs to be of type ms.Type, e.g. ms.float16, but is {type(mindspore_dtype)}." + ) + elif mindspore_dtype is not None: + model = model.to(mindspore_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.set_train(False) + if output_loading_info: + return model, loading_info + + return model + + def init_from_vae2d(self, path): + # default: tail init + # path: path to vae 2d model ckpt + vae2d_sd = ms.load_checkpoint(path) + vae_2d_keys = list(vae2d_sd.keys()) + vae_3d_keys = list(self.parameters_dict().keys()) + + # 3d -> 2d + map_dict = { + "conv.weight": "weight", + "conv.bias": "bias", + } + + new_state_dict = {} + for key_3d in vae_3d_keys: + if key_3d.startswith("loss"): + continue + + # param name mapping from vae-3d to vae-2d + key_2d = key_3d + for kw in map_dict: + key_2d = key_2d.replace(kw, map_dict[kw]) + + assert key_2d in vae_2d_keys, f"Key {key_2d} ({key_3d}) should be in 2D VAE" + + # set vae 3d state dict + shape_3d = self.parameters_dict()[key_3d].shape + shape_2d = vae2d_sd[key_2d].shape + if "bias" in key_2d: + assert shape_3d == shape_2d, f"Shape mismatch for key {key_3d} ({key_2d})" + new_state_dict[key_3d] = vae2d_sd[key_2d] + elif "norm" in key_2d: + assert shape_3d == shape_2d, f"Shape mismatch for key {key_3d} ({key_2d})" + new_state_dict[key_3d] = vae2d_sd[key_2d] + elif "conv" in key_2d or "nin_shortcut" in key_2d: + if shape_3d[:2] != shape_2d[:2]: + logger.info(key_2d, shape_3d, shape_2d) + w = vae2d_sd[key_2d] + new_w = mint.zeros(shape_3d, dtype=w.dtype) + # tail initialization + new_w[:, :, -1, :, :] = w # cin, cout, t, h, w + + new_w = ms.Parameter(new_w, name=key_3d) + + new_state_dict[key_3d] = new_w + elif "attn_1" in key_2d: + new_val = vae2d_sd[key_2d].expand_dims(axis=2) + new_param = ms.Parameter(new_val, name=key_3d) + new_state_dict[key_3d] = new_param + else: + raise NotImplementedError(f"Key {key_3d} ({key_2d}) not implemented") + + m, u = ms.load_param_into_net(self, new_state_dict) + if len(m) > 0: + logger.info("net param not loaded: ", m) + if len(u) > 0: + logger.info("checkpoint param not loaded: ", u) + + def init_from_ckpt(self, path, ignore_keys=list()): + # TODO: support auto download pretrained checkpoints + sd = ms.load_checkpoint(path) + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + logger.info("Deleting key {} from state_dict.".format(k)) + del sd[k] + + if "ema_state_dict" in sd and len(sd["ema_state_dict"]) > 0 and os.environ.get("NOT_USE_EMA_MODEL", 0) == 0: + logger.info("Load from ema model!") + sd = sd["ema_state_dict"] + sd = {key.replace("module.", ""): value for key, value in sd.items()} + elif "state_dict" in sd: + logger.info("Load from normal model!") + if "gen_model" in sd["state_dict"]: + sd = sd["state_dict"]["gen_model"] + else: + sd = sd["state_dict"] + + ms.load_param_into_net(self, sd, strict_load=False) + logger.info(f"Restored from {path}") + + def _encode(self, x): + # return latent distribution, N(mean, logvar) + h = self.encoder(x) + if self.use_quant_layer: + h = self.quant_conv(h) + mean, logvar = mint.split(h, [h.shape[1] // 2, h.shape[1] // 2], dim=1) + + return mean, logvar + + def sample(self, mean, logvar): + # sample z from latent distribution + logvar = mint.clamp(logvar, -30.0, 20.0) + std = self.exp(0.5 * logvar) + z = mean + std * self.stdnormal(size=mean.shape) + + return z + + def encode(self, x): + if self.use_tiling and ( + x.shape[-1] > self.tile_sample_min_size + or x.shape[-2] > self.tile_sample_min_size + or x.shape[-3] > self.tile_sample_min_size_t + ): + posterior_mean, posterior_logvar = self.tiled_encode(x) + else: + # embedding, get latent representation z + posterior_mean, posterior_logvar = self._encode(x) + z = self.sample(posterior_mean, posterior_logvar) + + return z + + def tiled_encode2d(self, x): + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = () + tile = None + for i in range(0, x.shape[3], overlap_size): + row = () + if self.depend is not None: + x = self.depend(x, tile) + for j in range(0, x.shape[4], overlap_size): + if self.depend is not None: + x = self.depend(x, tile) + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_size, + j : j + self.tile_sample_min_size, + ] + tile = self.encoder(tile) + if self.use_quant_layer: + tile = self.quant_conv(tile) + row += (tile,) + rows += (row,) + + result_rows = () + for i, row in enumerate(rows): + result_row = () + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row += (tile[:, :, :, :row_limit, :row_limit],) + result_rows += (mint.cat(result_row, dim=4),) + + moments = mint.cat(result_rows, dim=3) + return moments + + def tiled_encode(self, x): + t = x.shape[2] + t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t - 1)] + if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0: + t_chunk_start_end = [[0, t]] + else: + t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)] + if t_chunk_start_end[-1][-1] > t: + t_chunk_start_end[-1][-1] = t + elif t_chunk_start_end[-1][-1] < t: + last_start_end = [t_chunk_idx[-1], t] + t_chunk_start_end.append(last_start_end) + moments = [] + for idx, (start, end) in enumerate(t_chunk_start_end): + chunk_x = x[:, :, start:end] + if idx != 0: + moment = self.tiled_encode2d(chunk_x)[:, :, 1:] + else: + moment = self.tiled_encode2d(chunk_x) + moments.append(moment) + moments = mint.cat(moments, dim=2) + mean, logvar = mint.split(moments, [moments.shape[1] // 2, moments.shape[1] // 2], dim=1) + return mean, logvar + + def decode(self, z): + if self.use_tiling and ( + z.shape[-1] > self.tile_latent_min_size + or z.shape[-2] > self.tile_latent_min_size + or z.shape[-3] > self.tile_latent_min_size_t + ): + return self.tiled_decode(z) + if self.use_quant_layer: + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def tiled_decode2d(self, z): + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[3], overlap_size): + row = [] + for j in range(0, z.shape[4], overlap_size): + tile = z[ + :, + :, + :, + i : i + self.tile_latent_min_size, + j : j + self.tile_latent_min_size, + ] + if self.use_quant_layer: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(mint.cat(result_row, dim=4)) + + dec = mint.cat(result_rows, dim=3) + return dec + + def tiled_decode(self, x): + t = x.shape[2] + t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t - 1)] + if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0: + t_chunk_start_end = [[0, t]] + else: + t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)] + if t_chunk_start_end[-1][-1] > t: + t_chunk_start_end[-1][-1] = t + elif t_chunk_start_end[-1][-1] < t: + last_start_end = [t_chunk_idx[-1], t] + t_chunk_start_end.append(last_start_end) + dec_ = [] + for idx, (start, end) in enumerate(t_chunk_start_end): + chunk_x = x[:, :, start:end] + if idx != 0: + dec = self.tiled_decode2d(chunk_x)[:, :, 1:] + else: + dec = self.tiled_decode2d(chunk_x) + dec_.append(dec) + dec_ = mint.cat(dec_, dim=2) + return dec_ + + def construct(self, input): + # overall pass, mostly for training + posterior_mean, posterior_logvar = self._encode(input) + z = self.sample(posterior_mean, posterior_logvar) + + recons = self.decode(z) + + return recons, posterior_mean, posterior_logvar + + def enable_tiling(self, use_tiling: bool = True): + self.use_tiling = use_tiling + + def disable_tiling(self): + self.enable_tiling(False) + + def blend_v(self, a: ms.Tensor, b: ms.Tensor, blend_extent: int) -> ms.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: ms.Tensor, b: ms.Tensor, blend_extent: int) -> ms.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def validation_step(self, batch_idx): + raise NotImplementedError + + +class Encoder(nn.Cell): + """ + default value aligned to v1.1 vae config.json + """ + + def __init__( + self, + z_channels: int = 4, + hidden_size: int = 128, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = (), + conv_in: str = "Conv2d", + conv_out: str = "CausalConv3d", + attention: str = "AttnBlock3D", # already fixed, same as AttnBlock3DFix + resnet_blocks: Tuple[str] = ( + "ResnetBlock2D", + "ResnetBlock2D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + spatial_downsample: Tuple[str] = ( + "Downsample", + "Downsample", + "Downsample", + "", + ), + temporal_downsample: Tuple[str] = ( + "", + "TimeDownsampleRes2x", + "TimeDownsampleRes2x", + "", + ), + mid_resnet: str = "ResnetBlock3D", + dropout: float = 0.0, + resolution: int = 256, + num_res_blocks: int = 2, + double_z: bool = True, + upcast_sigmoid=False, + dtype=ms.float32, + **ignore_kwargs, + ): + """ + ch: hidden size, i.e. output channels of the first conv layer. typical: 128 + out_ch: placeholder, not used in Encoder + hidden_size_mult: channel multiply factors for each res block, also determine the number of res blocks. + Each block will be applied with spatial downsample x2 except for the last block. + In total, the spatial downsample rate = 2**(len(hidden_size_mult)-1) + resolution: spatial resolution, 256 + time_compress: the begging `time_compress` blocks will be applied with temporal downsample x2. + In total, the temporal downsample rate = 2**time_compress + """ + super().__init__() + assert len(resnet_blocks) == len(hidden_size_mult), print(hidden_size_mult, resnet_blocks) + self.num_resolutions = len(hidden_size_mult) + self.resolution = resolution + self.num_res_blocks = num_res_blocks + + self.dtype = dtype + self.upcast_sigmoid = (upcast_sigmoid,) + + # 1. Input conv + self.conv_in_name = conv_in + if conv_in == "Conv2d": + self.conv_in = nn.Conv2d(3, hidden_size, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True) + elif conv_in == "CausalConv3d": + self.conv_in = CausalConv3d( + 3, + hidden_size, + kernel_size=3, + stride=1, + padding=1, + ) + else: + raise NotImplementedError + + # 2. Downsample + curr_res = resolution + in_ch_mult = (1,) + tuple(hidden_size_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.CellList(auto_prefix=False) + self.downsample_flag = [0] * self.num_resolutions + self.time_downsample_flag = [0] * self.num_resolutions + for i_level in range(self.num_resolutions): + block = nn.CellList() + attn = nn.CellList() + block_in = hidden_size * in_ch_mult[i_level] # input channels + block_out = hidden_size * hidden_size_mult[i_level] # output channels + for i_block in range(self.num_res_blocks): + block.append( + resolve_str_to_obj(resnet_blocks[i_level])( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + dtype=self.dtype, + upcast_sigmoid=upcast_sigmoid, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(resolve_str_to_obj(attention)(block_in, dtype=self.dtype)) + + down = nn.Cell() + down.block = block + down.attn = attn + + # do spatial downsample according to config + if spatial_downsample[i_level]: + down.downsample = resolve_str_to_obj(spatial_downsample[i_level])(block_in, block_in, dtype=self.dtype) + curr_res = curr_res // 2 + self.downsample_flag[i_level] = 1 + else: + # TODO: still need it for 910b in new MS version? + down.downsample = nn.Identity() + + # do temporal downsample according to config + if temporal_downsample[i_level]: + # TODO: add dtype support? + down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])(block_in, block_in) + self.time_downsample_flag[i_level] = 1 + else: + # TODO: still need it for 910b in new MS version? + down.time_downsample = nn.Identity() + + down.update_parameters_name(prefix=self.param_prefix + f"down.{i_level}.") + self.down.append(down) + + # middle + self.mid = nn.Cell() + self.mid.block_1 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + dtype=self.dtype, + upcast_sigmoid=upcast_sigmoid, + ) + self.mid.attn_1 = resolve_str_to_obj(attention)(block_in, dtype=self.dtype) + self.mid.block_2 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + dtype=self.dtype, + upcast_sigmoid=upcast_sigmoid, + ) + self.mid.update_parameters_name(prefix=self.param_prefix + "mid.") + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + # self.norm_out = Normalize(block_in, extend=True) + + assert conv_out == "CausalConv3d", "Only CausalConv3d is supported for conv_out" + self.conv_out = resolve_str_to_obj(conv_out)( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + # copied from models.causalvideovae.model.modules.conv + def rearrange_in(self, x): + # b c f h w -> b f c h w + B, C, F, H, W = x.shape + x = mint.permute(x, (0, 2, 1, 3, 4)) + # -> (b*f c h w) + x = ops.reshape(x, (-1, C, H, W)) + + return x + + # copied from models.causalvideovae.model.modules.conv + def rearrange_out(self, x, F): + BF, D, H_, W_ = x.shape + # (b*f D h w) -> (b f D h w) + x = ops.reshape(x, (BF // F, F, D, H_, W_)) + # -> (b D f h w) + x = mint.permute(x, (0, 2, 1, 3, 4)) + + return x + + def construct(self, x): + # downsampling + if self.conv_in_name != "Conv2d": + hs = self.conv_in(x) + else: + F = x.shape[-3] + x = self.rearrange_in(x) + x = self.conv_in(x) + hs = self.rearrange_out(x, F) + + h = hs + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + # import pdb; pdb.set_trace() + h = self.down[i_level].block[i_block](hs) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs = h + # if hasattr(self.down[i_level], "downsample"): + # if not isinstance(self.down[i_level].downsample, nn.Identity): + if self.downsample_flag[i_level]: + hs = self.down[i_level].downsample(hs) + # if hasattr(self.down[i_level], "time_downsample"): + # if not isinstance(self.down[i_level].time_downsample, nn.Identity): + if self.time_downsample_flag[i_level]: + hs_down = self.down[i_level].time_downsample(hs) + hs = hs_down + + # middle + # h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h, upcast=self.upcast_sigmoid) + h = self.conv_out(h) + return h + + +class Decoder(nn.Cell): + """ + default value aligned to v1.1 vae config.json + """ + + def __init__( + self, + z_channels: int = 4, + hidden_size: int = 128, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = (), + conv_in: str = "CausalConv3d", + conv_out: str = "CausalConv3d", + attention: str = "AttnBlock3D", # already fixed, same as AttnBlock3DFix + resnet_blocks: Tuple[str] = ( + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + spatial_upsample: Tuple[str] = ("", "SpatialUpsample2x", "SpatialUpsample2x", "SpatialUpsample2x"), + temporal_upsample: Tuple[str] = ("", "", "TimeUpsampleRes2x", "TimeUpsampleRes2x"), + mid_resnet: str = "ResnetBlock3D", + dropout: float = 0.0, + resolution: int = 256, + num_res_blocks: int = 2, + double_z: bool = True, + upcast_sigmoid=False, + dtype=ms.float32, + **ignore_kwargs, + ): + super().__init__() + + self.num_resolutions = len(hidden_size_mult) + self.resolution = resolution + self.num_res_blocks = num_res_blocks + + self.dtype = dtype + self.upcast_sigmoid = upcast_sigmoid + + # 1. decode input z conv + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = hidden_size * hidden_size_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + # self.z_shape = (1, z_channels, curr_res, curr_res) + # logger.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + assert conv_in == "CausalConv3d", "Only CausalConv3d is supported for conv_in in Decoder currently" + self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, padding=1) + + # 2. middle + self.mid = nn.Cell() + self.mid.block_1 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, out_channels=block_in, dropout=dropout, dtype=self.dtype + ) + self.mid.attn_1 = resolve_str_to_obj(attention)(block_in, dtype=self.dtype) + self.mid.block_2 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, out_channels=block_in, dropout=dropout, dtype=self.dtype + ) + self.mid.update_parameters_name(prefix=self.param_prefix + "mid.") + + # 3. upsampling + self.up = nn.CellList(auto_prefix=False) + self.upsample_flag = [0] * self.num_resolutions + self.time_upsample_flag = [0] * self.num_resolutions + # i_level: 3 -> 2 -> 1 -> 0 + for i_level in reversed(range(self.num_resolutions)): + block = nn.CellList() + attn = nn.CellList() + block_out = hidden_size * hidden_size_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + resolve_str_to_obj(resnet_blocks[i_level])( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + dtype=self.dtype, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(resolve_str_to_obj(attention)(block_in, dtype=self.dtype)) + up = nn.Cell() + up.block = block + up.attn = attn + # do spatial upsample x2 except for the first block + if spatial_upsample[i_level]: + up.upsample = resolve_str_to_obj(spatial_upsample[i_level])(block_in, block_in, dtype=self.dtype) + curr_res = curr_res * 2 + self.upsample_flag[i_level] = 1 + else: + up.upsample = nn.Identity() + # do temporal upsample x2 in the bottom tc blocks + if temporal_upsample[i_level]: + # TODO: support dtype? + up.time_upsample = resolve_str_to_obj(temporal_upsample[i_level])(block_in, block_in) + self.time_upsample_flag[i_level] = 1 + else: + up.time_upsample = nn.Identity() + + up.update_parameters_name(prefix=self.param_prefix + f"up.{i_level}.") + if len(self.up) != 0: + self.up.insert(0, up) + else: + self.up.append(up) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + # self.norm_out = Normalize(block_in, extend=True) + + assert conv_out == "CausalConv3d", "Only CausalConv3d is supported for conv_out in Decoder currently" + self.conv_out = CausalConv3d(block_in, 3, kernel_size=3, padding=1) + + def construct(self, z): + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + i_level = self.num_resolutions + while i_level > 0: + i_level -= 1 + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + # if hasattr(self.up[i_level], 'upsample'): + # if not isinstance(self.up[i_level].upsample, nn.Identity): + if self.upsample_flag[i_level]: + h = self.up[i_level].upsample(h) + + # if hasattr(self.up[i_level], 'time_upsample'): + # if not isinstance(self.up[i_level].time_upsample, nn.Identity): + if self.time_upsample_flag[i_level]: + h = self.up[i_level].time_upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h, upcast=self.upcast_sigmoid) + h = self.conv_out(h) + return h diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py new file mode 100644 index 0000000000..32a1139bd9 --- /dev/null +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py @@ -0,0 +1,799 @@ +import logging +import os +from typing import List + +from opensora.npu_config import npu_config + +import mindspore as ms +from mindspore import mint, nn + +from mindone.diffusers import __version__ +from mindone.diffusers.configuration_utils import register_to_config +from mindone.diffusers.models.modeling_utils import load_state_dict +from mindone.diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, _add_variant, _get_model_file + +from ..modeling_videobase import VideoBaseAE +from ..modules import ( + AttnBlock3DFix, + CausalConv3d, + Conv2d, + HaarWaveletTransform3D, + InverseHaarWaveletTransform3D, + Normalize, + ResnetBlock2D, + ResnetBlock3D, + nonlinearity, +) +from ..registry import ModelRegistry +from ..utils.model_utils import resolve_str_to_obj + +logger = logging.getLogger(__name__) + + +class Encoder(VideoBaseAE): + @register_to_config + def __init__( + self, + latent_dim: int = 8, + base_channels: int = 128, + num_resblocks: int = 2, + energy_flow_hidden_size: int = 64, + dropout: float = 0.0, + use_attention: bool = True, + norm_type: str = "groupnorm", + l1_dowmsample_block: str = "Downsample", + l1_downsample_wavelet: str = "HaarWaveletTransform2D", + l2_dowmsample_block: str = "Spatial2xTime2x3DDownsample", + l2_downsample_wavelet: str = "HaarWaveletTransform3D", + dtype=ms.float32, + ) -> None: + super().__init__() + + self.down1 = nn.SequentialCell( + Conv2d( + 24, + base_channels, + kernel_size=3, + stride=1, + padding=1, + pad_mode="pad", + has_bias=True, + ).to_float(dtype), + *[ + ResnetBlock2D( + in_channels=base_channels, + out_channels=base_channels, + dropout=dropout, + norm_type=norm_type, + dtype=dtype, + ) + for _ in range(num_resblocks) + ], + resolve_str_to_obj(l1_dowmsample_block)(in_channels=base_channels, out_channels=base_channels, dtype=dtype), + ) + self.down2 = nn.SequentialCell( + Conv2d( + base_channels + energy_flow_hidden_size, + base_channels * 2, + kernel_size=3, + stride=1, + padding=1, + pad_mode="pad", + has_bias=True, + ).to_float(dtype), + *[ + ResnetBlock3D( + in_channels=base_channels * 2, + out_channels=base_channels * 2, + dropout=dropout, + norm_type=norm_type, + dtype=dtype, + ) + for _ in range(num_resblocks) + ], + resolve_str_to_obj(l2_dowmsample_block)(base_channels * 2, base_channels * 2, dtype=dtype), + ) + # Connection + if l1_dowmsample_block == "Downsample": # Bad code. For temporal usage. + l1_channels = 12 + else: + l1_channels = 24 + + self.connect_l1 = Conv2d( + l1_channels, + energy_flow_hidden_size, + kernel_size=3, + stride=1, + padding=1, + pad_mode="pad", + has_bias=True, + ).to_float(dtype) + self.connect_l2 = Conv2d( + 24, + energy_flow_hidden_size, + kernel_size=3, + stride=1, + padding=1, + pad_mode="pad", + has_bias=True, + ).to_float(dtype) + # Mid + mid_layers = [ + ResnetBlock3D( + in_channels=base_channels * 2 + energy_flow_hidden_size, + out_channels=base_channels * 4, + dropout=dropout, + norm_type=norm_type, + dtype=dtype, + ), + ResnetBlock3D( + in_channels=base_channels * 4, + out_channels=base_channels * 4, + dropout=dropout, + norm_type=norm_type, + dtype=dtype, + ), + ] + if use_attention: + mid_layers.insert(1, AttnBlock3DFix(in_channels=base_channels * 4, norm_type=norm_type, dtype=dtype)) + self.mid = nn.SequentialCell(*mid_layers) + self.norm_out = Normalize(base_channels * 4, norm_type=norm_type) + self.conv_out = CausalConv3d(base_channels * 4, latent_dim * 2, kernel_size=3, stride=1, padding=1) + + self.wavelet_tranform_l1 = resolve_str_to_obj(l1_downsample_wavelet)(dtype=dtype) + self.wavelet_tranform_l2 = resolve_str_to_obj(l2_downsample_wavelet)(dtype=dtype) + + def construct(self, coeffs): + l1_coeffs = coeffs[:, :3] + l1_coeffs = self.wavelet_tranform_l1(l1_coeffs) + l1 = self.connect_l1(l1_coeffs) + l2_coeffs = self.wavelet_tranform_l2(l1_coeffs[:, :3]) + l2 = self.connect_l2(l2_coeffs) + + h = self.down1(coeffs) + h = mint.cat([h, l1], dim=1) + h = self.down2(h) + h = mint.cat([h, l2], dim=1) + h = self.mid(h) + + if npu_config is None: + h = self.norm_out(h) + else: + h = npu_config.run_group_norm(self.norm_out, h) + + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(VideoBaseAE): + @register_to_config + def __init__( + self, + latent_dim: int = 8, + base_channels: int = 128, + num_resblocks: int = 2, + dropout: float = 0.0, + energy_flow_hidden_size: int = 128, + use_attention: bool = True, + norm_type: str = "groupnorm", + t_interpolation: str = "nearest", + connect_res_layer_num: int = 1, + l1_upsample_block: str = "Upsample", + l1_upsample_wavelet: str = "InverseHaarWaveletTransform2D", + l2_upsample_block: str = "Spatial2xTime2x3DUpsample", + l2_upsample_wavelet: str = "InverseHaarWaveletTransform3D", + dtype=ms.float32, + ) -> None: + super().__init__() + + self.energy_flow_hidden_size = energy_flow_hidden_size + + self.conv_in = CausalConv3d(latent_dim, base_channels * 4, kernel_size=3, stride=1, padding=1) + mid_layers = [ + ResnetBlock3D( + in_channels=base_channels * 4, + out_channels=base_channels * 4, + dropout=dropout, + norm_type=norm_type, + dtype=dtype, + ), + ResnetBlock3D( + in_channels=base_channels * 4, + out_channels=base_channels * 4 + energy_flow_hidden_size, + dropout=dropout, + norm_type=norm_type, + dtype=dtype, + ), + ] + if use_attention: + mid_layers.insert(1, AttnBlock3DFix(in_channels=base_channels * 4, norm_type=norm_type, dtype=dtype)) + self.mid = nn.SequentialCell(*mid_layers) + self.up2 = nn.SequentialCell( + *[ + ResnetBlock3D( + in_channels=base_channels * 4, + out_channels=base_channels * 4, + dropout=dropout, + norm_type=norm_type, + dtype=dtype, + ) + for _ in range(num_resblocks) + ], + resolve_str_to_obj(l2_upsample_block)( + base_channels * 4, base_channels * 4, t_interpolation=t_interpolation, dtype=dtype + ), + ResnetBlock3D( + in_channels=base_channels * 4, + out_channels=base_channels * 4 + energy_flow_hidden_size, + dropout=dropout, + norm_type=norm_type, + dtype=dtype, + ), + ) + self.up1 = nn.SequentialCell( + *[ + ResnetBlock3D( + in_channels=base_channels * (4 if i == 0 else 2), + out_channels=base_channels * 2, + dropout=dropout, + norm_type=norm_type, + dtype=dtype, + ) + for i in range(num_resblocks) + ], + resolve_str_to_obj(l1_upsample_block)( + in_channels=base_channels * 2, out_channels=base_channels * 2, dtype=dtype + ), + ResnetBlock3D( + in_channels=base_channels * 2, + out_channels=base_channels * 2, + dropout=dropout, + norm_type=norm_type, + dtype=dtype, + ), + ) + self.layer = nn.SequentialCell( + *[ + ResnetBlock3D( + in_channels=base_channels * (2 if i == 0 else 1), + out_channels=base_channels, + dropout=dropout, + norm_type=norm_type, + dtype=dtype, + ) + for i in range(2) + ], + ) + # Connection + if l1_upsample_block == "Upsample": # Bad code. For temporal usage. + l1_channels = 12 + else: + l1_channels = 24 + self.connect_l1 = nn.SequentialCell( + *[ + ResnetBlock3D( + in_channels=base_channels, + out_channels=base_channels, + dropout=dropout, + norm_type=norm_type, + dtype=dtype, + ) + for _ in range(connect_res_layer_num) + ], + Conv2d( + base_channels, + l1_channels, + kernel_size=3, + stride=1, + padding=1, + pad_mode="pad", + has_bias=True, + ).to_float(dtype), + ) + self.connect_l2 = nn.SequentialCell( + *[ + ResnetBlock3D( + in_channels=base_channels, + out_channels=base_channels, + dropout=dropout, + norm_type=norm_type, + dtype=dtype, + ) + for _ in range(connect_res_layer_num) + ], + Conv2d( + base_channels, + 24, + kernel_size=3, + stride=1, + padding=1, + pad_mode="pad", + has_bias=True, + ).to_float(dtype), + ) + # Out + self.norm_out = Normalize(base_channels, norm_type=norm_type) + self.conv_out = Conv2d( + base_channels, + 24, + kernel_size=3, + stride=1, + padding=1, + pad_mode="pad", + has_bias=True, + ).to_float(dtype) + + self.inverse_wavelet_tranform_l1 = resolve_str_to_obj(l1_upsample_wavelet)(dtype=dtype) + self.inverse_wavelet_tranform_l2 = resolve_str_to_obj(l2_upsample_wavelet)(dtype=dtype) + + def construct(self, z): + h = self.conv_in(z) + h = self.mid(h) + l2_coeffs = self.connect_l2(h[:, -self.energy_flow_hidden_size :]) + l2 = self.inverse_wavelet_tranform_l2(l2_coeffs) + + h = self.up2(h[:, : -self.energy_flow_hidden_size]) + + l1_coeffs = h[:, -self.energy_flow_hidden_size :] + l1_coeffs = self.connect_l1(l1_coeffs) + l1_coeffs[:, :3] = l1_coeffs[:, :3] + l2 + l1 = self.inverse_wavelet_tranform_l1(l1_coeffs) + + h = self.up1(h[:, : -self.energy_flow_hidden_size]) + + h = self.layer(h) + h = npu_config.run_group_norm(self.norm_out, h) + h = nonlinearity(h) + h = self.conv_out(h) + h[:, :3] = h[:, :3] + l1 + return h + + +@ModelRegistry.register("WFVAE") +class WFVAEModel(VideoBaseAE): + @register_to_config + def __init__( + self, + latent_dim: int = 8, + base_channels: int = 128, + encoder_num_resblocks: int = 2, + encoder_energy_flow_hidden_size: int = 64, + decoder_num_resblocks: int = 2, + decoder_energy_flow_hidden_size: int = 128, + use_attention: bool = True, + dropout: float = 0.0, + norm_type: str = "groupnorm", + t_interpolation: str = "nearest", + connect_res_layer_num: int = 1, + scale: List[float] = [0.18215, 0.18215, 0.18215, 0.18215, 0.18215, 0.18215, 0.18215, 0.18215], + shift: List[float] = [0, 0, 0, 0, 0, 0, 0, 0], + # Module config + l1_dowmsample_block: str = "Downsample", + l1_downsample_wavelet: str = "HaarWaveletTransform2D", + l2_dowmsample_block: str = "Spatial2xTime2x3DDownsample", + l2_downsample_wavelet: str = "HaarWaveletTransform3D", + l1_upsample_block: str = "Upsample", + l1_upsample_wavelet: str = "InverseHaarWaveletTransform2D", + l2_upsample_block: str = "Spatial2xTime2x3DUpsample", + l2_upsample_wavelet: str = "InverseHaarWaveletTransform3D", + dtype=ms.float32, + ) -> None: + super().__init__() + self.use_tiling = False + + # Hardcode for now + self.t_chunk_enc = 16 + self.t_upsample_times = 4 // 2 + self.t_chunk_dec = 4 + self.use_quant_layer = False + self.encoder = Encoder( + latent_dim=latent_dim, + base_channels=base_channels, + num_resblocks=encoder_num_resblocks, + energy_flow_hidden_size=encoder_energy_flow_hidden_size, + dropout=dropout, + use_attention=use_attention, + norm_type=norm_type, + l1_dowmsample_block=l1_dowmsample_block, + l1_downsample_wavelet=l1_downsample_wavelet, + l2_dowmsample_block=l2_dowmsample_block, + l2_downsample_wavelet=l2_downsample_wavelet, + dtype=dtype, + ) + self.decoder = Decoder( + latent_dim=latent_dim, + base_channels=base_channels, + num_resblocks=decoder_num_resblocks, + energy_flow_hidden_size=decoder_energy_flow_hidden_size, + dropout=dropout, + use_attention=use_attention, + norm_type=norm_type, + t_interpolation=t_interpolation, + connect_res_layer_num=connect_res_layer_num, + l1_upsample_block=l1_upsample_block, + l1_upsample_wavelet=l1_upsample_wavelet, + l2_upsample_block=l2_upsample_block, + l2_upsample_wavelet=l2_upsample_wavelet, + dtype=dtype, + ) + + # Set cache offset for trilinear lossless upsample. + self._set_cache_offset([self.decoder.up2, self.decoder.connect_l2, self.decoder.conv_in, self.decoder.mid], 1) + self._set_cache_offset( + [self.decoder.up2[-2:], self.decoder.up1, self.decoder.connect_l1, self.decoder.layer], + self.t_upsample_times, + ) + + self.exp = mint.exp + self.stdnormal = mint.normal + + self.update_parameters_name() # update parameter names to solve pname mismatch + + def get_encoder(self): + if self.use_quant_layer: + return [self.quant_conv, self.encoder] + return [self.encoder] + + def get_decoder(self): + if self.use_quant_layer: + return [self.post_quant_conv, self.decoder] + return [self.decoder] + + def _empty_causal_cached(self, parent): + for name, module in parent.cells_and_names(): + if hasattr(module, "causal_cached"): + module.causal_cached = None + + def _set_causal_cached(self, enable_cached=True): + for name, module in self.cells_and_names(): + if hasattr(module, "enable_cached"): + module.enable_cached = enable_cached + + def _set_cache_offset(self, modules, cache_offset=0): + for module in modules: + for submodule in module.cells(): + if hasattr(submodule, "cache_offset"): + submodule.cache_offset = cache_offset + + def build_chunk_start_end(self, t, decoder_mode=False): + start_end = [[0, 1]] + start = 1 + end = start + while True: + if start >= t: + break + end = min(t, end + (self.t_chunk_dec if decoder_mode else self.t_chunk_enc)) + start_end.append([start, end]) + start = end + return start_end + + def encode(self, x, sample_posterior=True): + posterior_mean, posterior_logvar = self._encode(x) + if sample_posterior: + z = self.sample(posterior_mean, posterior_logvar) + else: + z = posterior_mean + + return z + + def _encode(self, x): + self._empty_causal_cached(self.encoder) + + coeffs = HaarWaveletTransform3D()(x) + + if self.use_tiling: + h = self.tile_encode(coeffs) + else: + h = self.encoder(coeffs) + if self.use_quant_layer: + h = self.quant_conv(h) + posterior_mean, posterior_logvar = mint.split(h, [h.shape[1] // 2, h.shape[1] // 2], dim=1) + return posterior_mean, posterior_logvar + + def tile_encode(self, x): + b, c, t, h, w = x.shape + + start_end = self.build_chunk_start_end(t) + result = [] + for start, end in start_end: + chunk = x[:, :, start:end, :, :] + chunk = self.encoder(chunk) + if self.use_quant_layer: + chunk = self.encoder(chunk) + result.append(chunk) + + return mint.cat(result, dim=2) + + def decode(self, z): + self._empty_causal_cached(self.decoder) + + if self.use_tiling: + dec = self.tile_decode(z) + else: + if self.use_quant_layer: + z = self.post_quant_conv(z) + dec = self.decoder(z) + + dec = InverseHaarWaveletTransform3D()(dec) + return dec + + def tile_decode(self, x): + b, c, t, h, w = x.shape + + start_end = self.build_chunk_start_end(t, decoder_mode=True) + + result = [] + for start, end in start_end: + if end + 1 < t: + chunk = x[:, :, start : end + 1, :, :] + else: + chunk = x[:, :, start:end, :, :] + + if self.use_quant_layer: + chunk = self.post_quant_conv(chunk) + chunk = self.decoder(chunk) + + if end + 1 < t: + chunk = chunk[:, :, :-2] + result.append(chunk) + else: + result.append(chunk) + + return mint.cat(result, dim=2) + + def sample(self, mean, logvar): + # sample z from latent distribution + logvar = mint.clamp(logvar, -30.0, 20.0) + std = self.exp(0.5 * logvar) + z = mean + std * self.stdnormal(size=mean.shape) + + return z + + def construct(self, input, sample_posterior=True): + # overall pass, mostly for training + posterior_mean, posterior_logvar = self._encode(input) + if sample_posterior: + z = self.sample(posterior_mean, posterior_logvar) + else: + z = posterior_mean + + recons = self.decode(z) + + return recons, posterior_mean, posterior_logvar + + def get_last_layer(self): + if hasattr(self.decoder.conv_out, "conv"): + return self.decoder.conv_out.conv.weight + else: + return self.decoder.conv_out.weight + + def enable_tiling(self, use_tiling: bool = True): + self.use_tiling = use_tiling + self._set_causal_cached(use_tiling) + + def disable_tiling(self): + self.enable_tiling(False) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + state_dict = kwargs.pop("state_dict", None) # additional key argument + cache_dir = kwargs.pop("cache_dir", None) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + mindspore_dtype = kwargs.pop("mindspore_dtype", None) + subfolder = kwargs.pop("subfolder", None) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + ignore_prefix = kwargs.pop("ignore_prefix", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + **kwargs, + ) + + # load model + model_file = None + if from_flax: + raise NotImplementedError("loading flax checkpoint in mindspore model is not yet supported.") + else: + if state_dict is None: # edits: only search for model_file if state_dict is not provided + if use_safetensors: + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + except IOError as e: + if not allow_pickle: + raise e + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + model = cls.from_config(config, **unused_kwargs) + if state_dict is None: # edits: only load model_file if state_dict is None + state_dict = load_state_dict(model_file, variant=variant) + if ignore_prefix is not None: + assert len(ignore_prefix) > 0, "the ignore_prefix must not be empty" + num_params = len(state_dict) + state_dict = dict( + [ + (k, v) + for k, v in state_dict.items() + if not any([k.startswith(prefix) for prefix in ignore_prefix]) + ] + ) + logger.info( + f"Excluding the parameters with prefix in {ignore_prefix}: exclude {num_params - len(state_dict)} out of {num_params} params" + ) + model._convert_deprecated_attention_blocks(state_dict) + + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type): + raise ValueError( + f"{mindspore_dtype} needs to be of type ms.Type, e.g. ms.float16, but is {type(mindspore_dtype)}." + ) + elif mindspore_dtype is not None: + model = model.to(mindspore_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.set_train(False) + if output_loading_info: + return model, loading_info + + return model + + def init_from_vae2d(self, path): + # default: tail init + # path: path to vae 2d model ckpt + vae2d_sd = ms.load_checkpoint(path) + vae_2d_keys = list(vae2d_sd.keys()) + vae_3d_keys = list(self.parameters_dict().keys()) + + # 3d -> 2d + map_dict = { + "conv.weight": "weight", + "conv.bias": "bias", + } + + new_state_dict = {} + for key_3d in vae_3d_keys: + if key_3d.startswith("loss"): + continue + + # param name mapping from vae-3d to vae-2d + key_2d = key_3d + for kw in map_dict: + key_2d = key_2d.replace(kw, map_dict[kw]) + + assert key_2d in vae_2d_keys, f"Key {key_2d} ({key_3d}) should be in 2D VAE" + + # set vae 3d state dict + shape_3d = self.parameters_dict()[key_3d].shape + shape_2d = vae2d_sd[key_2d].shape + if "bias" in key_2d: + assert shape_3d == shape_2d, f"Shape mismatch for key {key_3d} ({key_2d})" + new_state_dict[key_3d] = vae2d_sd[key_2d] + elif "norm" in key_2d: + assert shape_3d == shape_2d, f"Shape mismatch for key {key_3d} ({key_2d})" + new_state_dict[key_3d] = vae2d_sd[key_2d] + elif "conv" in key_2d or "nin_shortcut" in key_2d: + if shape_3d[:2] != shape_2d[:2]: + logger.info(key_2d, shape_3d, shape_2d) + w = vae2d_sd[key_2d] + new_w = mint.zeros(shape_3d, dtype=w.dtype) + # tail initialization + new_w[:, :, -1, :, :] = w # cin, cout, t, h, w + + new_w = ms.Parameter(new_w, name=key_3d) + + new_state_dict[key_3d] = new_w + elif "attn_1" in key_2d: + new_val = vae2d_sd[key_2d].expand_dims(axis=2) + new_param = ms.Parameter(new_val, name=key_3d) + new_state_dict[key_3d] = new_param + else: + raise NotImplementedError(f"Key {key_3d} ({key_2d}) not implemented") + + m, u = ms.load_param_into_net(self, new_state_dict) + if len(m) > 0: + logger.info("net param not loaded: ", m) + if len(u) > 0: + logger.info("checkpoint param not loaded: ", u) + + def init_from_ckpt(self, path, ignore_keys=list()): + # TODO: support auto download pretrained checkpoints + sd = ms.load_checkpoint(path) + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + logger.info("Deleting key {} from state_dict.".format(k)) + del sd[k] + + if "ema_state_dict" in sd and len(sd["ema_state_dict"]) > 0 and os.environ.get("NOT_USE_EMA_MODEL", 0) == 0: + logger.info("Load from ema model!") + sd = sd["ema_state_dict"] + sd = {key.replace("module.", ""): value for key, value in sd.items()} + elif "state_dict" in sd: + logger.info("Load from normal model!") + if "gen_model" in sd["state_dict"]: + sd = sd["state_dict"]["gen_model"] + else: + sd = sd["state_dict"] + + ms.load_param_into_net(self, sd, strict_load=False) + logger.info(f"Restored from {path}") diff --git a/examples/opensora_pku/opensora/models/causalvideovae/sample/rec_video_vae.py b/examples/opensora_pku/opensora/models/causalvideovae/sample/rec_video_vae.py new file mode 100644 index 0000000000..e18916b615 --- /dev/null +++ b/examples/opensora_pku/opensora/models/causalvideovae/sample/rec_video_vae.py @@ -0,0 +1,161 @@ +import argparse +import os +import sys + +from tqdm import tqdm + +sys.path.append(".") +from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info +from opensora.models.causalvideovae.model import ModelRegistry +from opensora.models.causalvideovae.model.dataset_videobase import VideoDataset, create_dataloader +from opensora.utils.ms_utils import init_env +from opensora.utils.utils import get_precision + +from mindone.visualize.videos import save_videos + + +def main(args: argparse.Namespace): + rank_id, device_num = init_env( + args.mode, + seed=args.seed, + distributed=args.use_parallel, + device_target=args.device, + max_device_memory=args.max_device_memory, + parallel_mode=args.parallel_mode, + precision_mode=args.precision_mode, + sp_size=args.sp_size, + jit_level=args.jit_level, + jit_syntax_level=args.jit_syntax_level, + ) + + real_video_dir = args.real_video_dir + generated_video_dir = args.generated_video_dir + sample_rate = args.sample_rate + height, width = args.resolution, args.resolution + crop_size = (height, width) if args.crop_size is None else (args.crop_size, args.crop_size) + num_frames = args.num_frames + sample_rate = args.sample_rate + + sample_fps = args.sample_fps + batch_size = args.batch_size + num_workers = args.num_workers + subset_size = args.subset_size + + if not os.path.exists(args.generated_video_dir): + os.makedirs(args.generated_video_dir, exist_ok=True) + + data_type = get_precision(args.precision) + + # ---- Load Model ---- + model_cls = ModelRegistry.get_model(args.model_name) + vae = model_cls.from_pretrained(args.from_pretrained) + vae = vae.to(data_type) + if args.enable_tiling: + vae.enable_tiling() + vae.tile_overlap_factor = args.tile_overlap_factor + + # ---- Prepare Dataset ---- + ds_config = dict( + data_folder=real_video_dir, + size=(height, width), + crop_size=crop_size, + disable_flip=True, + random_crop=False, + sample_stride=sample_rate, + sample_n_frames=num_frames, + dynamic_start_index=args.dynamic_start_index, + ) + dataset = VideoDataset(**ds_config) + if subset_size: + indices = range(subset_size) + dataset.dataset = [dataset.dataset[i] for i in indices] + dataset.length = len(dataset) + + dataloader = create_dataloader( + dataset, + batch_size=batch_size, + ds_name="video", + num_parallel_workers=num_workers, + shuffle=False, # be in order + device_num=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), + rank_id=rank_id if not get_sequence_parallel_state() else hccl_info.group_id, + drop_remainder=False, + ) + + # ---- Inference ---- + for batch in tqdm(dataloader): + x, file_names = batch["video"], batch["file_name"] + x = x.to(dtype=data_type) # b c t h w + x = x * 2 - 1 + encode_result = vae.encode(x) + if isinstance(encode_result, tuple): + encode_result = encode_result[0] + latents = encode_result.to(data_type) + video_recon = vae.decode(latents) + if isinstance(video_recon, tuple): + video_recon = video_recon[0] + for idx, video in enumerate(video_recon): + output_path = os.path.join(generated_video_dir, file_names[idx]) + if args.output_origin: + os.makedirs(os.path.join(generated_video_dir, "origin/"), exist_ok=True) + origin_output_path = os.path.join(generated_video_dir, "origin/", file_names[idx]) + save_videos(x[idx], origin_output_path, loop=0, fps=sample_fps / sample_rate) + + save_videos( + video, + output_path, + loop=0, + fps=sample_fps / sample_rate, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--real_video_dir", type=str, default="") + parser.add_argument("--generated_video_dir", type=str, default="") + parser.add_argument("--from_pretrained", type=str, default="") + parser.add_argument("--sample_fps", type=int, default=30) + parser.add_argument("--resolution", type=int, default=336) + parser.add_argument("--crop_size", type=int, default=None) + parser.add_argument("--num_frames", type=int, default=17) + parser.add_argument("--sample_rate", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--subset_size", type=int, default=None) + parser.add_argument("--tile_overlap_factor", type=float, default=0.25) + parser.add_argument("--enable_tiling", action="store_true") + parser.add_argument("--output_origin", action="store_true") + parser.add_argument("--model_name", type=str, default=None, help="") + parser.add_argument( + "--dynamic_start_index", + action="store_true", + help="Whether to use a random frame as the starting frame for reconstruction. Default is False for the ease of evaluation.", + ) + parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") + parser.add_argument( + "--precision", + default="bf16", + type=str, + choices=["fp32", "fp16", "bf16"], + help="mixed precision type, if fp32, all layer precision is float32 (amp_level=O0), \ + if bf16 or fp16, amp_level==O2, part of layers will compute in bf16 or fp16 such as matmul, dense, conv.", + ) + parser.add_argument("--device", type=str, default="Ascend", help="Ascend or GPU") + parser.add_argument("--max_device_memory", type=str, default=None, help="e.g. `30GB` for 910a, `59GB` for 910b") + parser.add_argument("--use_parallel", action="store_true", help="use parallel") + parser.add_argument( + "--parallel_mode", default="data", type=str, choices=["data", "optim"], help="parallel mode: data, optim" + ) + parser.add_argument("--jit_level", default="O0", help="Set jit level: # O0: KBK, O1:DVM, O2: GE") + parser.add_argument( + "--jit_syntax_level", default="strict", choices=["strict", "lax"], help="Set jit syntax level: strict or lax" + ) + parser.add_argument("--seed", type=int, default=4, help="Inference seed") + parser.add_argument( + "--precision_mode", + default=None, + type=str, + help="If specified, set the precision mode for Ascend configurations.", + ) + args = parser.parse_args() + main(args) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py index 147e0a1c0a..0a05a21167 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py @@ -6,6 +6,7 @@ import numpy as np from opensora.acceleration.communications import AllToAll_SBH from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info +from opensora.npu_config import npu_config import mindspore as ms from mindspore import Parameter, mint, nn, ops @@ -765,7 +766,12 @@ def construct(self, latent, num_frames): b, c, t, h, w = latent.shape video_latent, image_latent = None, None height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size - latent = self.proj(latent) + + if npu_config is not None and npu_config.on_npu: + latent = npu_config.run_conv3d(self.proj, latent, latent.dtype) + else: + latent = self.proj(latent) + if self.flatten: # b c t h w -> (b t) (h w) c latent = latent.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c) @@ -1007,7 +1013,11 @@ def construct(self, x, attention_mask, t, h, w): x = x.reshape(b, t, h, w, -1).permute(0, 4, 1, 2, 3) x_dtype = x.dtype - x = self.layer(x).to(x_dtype) + (x if self.down_shortcut else 0) + if npu_config is not None and npu_config.on_npu: + conv_out = npu_config.run_conv3d(self.layer, x, x_dtype) + else: + conv_out = self.layer(x) + x = conv_out + (x if self.down_shortcut else 0) # b d (t dt) (h dh) (w dw) -> (b dt dh dw) (t h w) d dt, dh, dw = self.down_factor diff --git a/examples/opensora_pku/opensora/npu_config.py b/examples/opensora_pku/opensora/npu_config.py new file mode 100644 index 0000000000..2afde53e51 --- /dev/null +++ b/examples/opensora_pku/opensora/npu_config.py @@ -0,0 +1,344 @@ +import bisect +import gc +import math +import os +import subprocess +from contextlib import contextmanager + +from opensora.utils.ms_utils import init_env + +import mindspore as ms +from mindspore import mint, ops + +from mindone.utils.version_control import choose_flash_attention_dtype + + +@contextmanager +def set_run_dtype(x, dtype=None): + # 保存原始环境变量的值(如果存在) + npu_config.original_run_dtype = x.dtype + # 设置环境变量为指定的值 + npu_config.current_run_dtype = dtype + try: + # Yield control back to the body of the `with` statement + yield + finally: + # 恢复原始的环境变量值 + npu_config.current_run_dtype = None + npu_config.original_run_dtype = None + + +class NPUConfig: + N_NPU_PER_NODE = 8 + + def __init__(self): + self.on_npu = True + self.node_world_size = self.N_NPU_PER_NODE + self.profiling = False + self.profiling_step = 5 + self.enable_FA = True + self.enable_FP32 = False + self.load_pickle = True + self.use_small_dataset = False + self.current_run_dtype = None + self.original_run_dtype = None + + self.replaced_type = ms.float32 + self.conv_dtype = ms.float16 + if self.enable_FA and self.enable_FP32: + self.inf_float = -10000.0 + else: + self.inf_float = -10000.0 + + if self.use_small_dataset: + self.load_pickle = False + + self._loss = [] + self.work_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + self.pickle_save_path = f"{self.work_path}/pickles" + + gc.set_threshold(700, 10, 10000) + self.fa_mask_dtype = choose_flash_attention_dtype() + self.flash_attn_valid_head_dims = [64, 80, 96, 120, 128, 256] + self.FA_dtype = ms.bfloat16 + assert self.FA_dtype in [ms.float16, ms.bfloat16], f"Unsupported flash-attention dtype: {self.FA_dtype}" + + def set_npu_env(self, args): + rank_id, device_num = init_env( + mode=args.mode, + device_target=args.device, + precision_mode=getattr(args, "precision_mode", None), + jit_level=getattr(args, "jit_level", None), + jit_syntax_level=getattr(args, "jit_syntax_level", "strict"), + ) + self.rank = rank_id + self.bind_thread_to_cpu() + return rank_id, device_num + + def get_total_cores(self): + try: + total_cores = os.sysconf("SC_NPROCESSORS_ONLN") + except (AttributeError, ValueError): + total_cores = os.cpu_count() + return total_cores + + def bind_thread_to_cpu(self): + total_cores = self.get_total_cores() + # 每个卡的核心数量 + cores_per_rank = total_cores // 8 + # 计算本地rank + local_rank = self.rank % 8 + # 计算当前 rank 的 CPU 核范围 + start_core = local_rank * cores_per_rank + end_core = start_core + cores_per_rank - 1 + # 构建 CPU 核范围字符串 + cpu_cores_range = f"{start_core}-{end_core}" + pid = os.getpid() + command = f"taskset -cp {cpu_cores_range} {pid}" + + subprocess.run(command, shell=True, check=True) + return f"Binding Cores: {self.rank} {pid} {cpu_cores_range}" + + def get_attention_mask(self, attention_mask, repeat_num): + if self.on_npu and attention_mask is not None: + if npu_config.enable_FA: + attention_mask = attention_mask.to(ms.float16) + attention_mask = attention_mask.repeat_interleave(repeat_num, dim=-2) + return attention_mask + + def set_current_run_dtype(self, variables): + if variables[0].dtype != self.current_run_dtype and self.current_run_dtype is not None: + for index, var in enumerate(variables): + variables[index] = var.to(self.current_run_dtype) + return tuple(variables) + + def restore_dtype(self, x): + if x.dtype != self.original_run_dtype and self.original_run_dtype is not None: + x = x.to(self.original_run_dtype) + return x + + def get_node_id(self): + return self.rank // self.node_world_size + + def get_node_size(self): + return self.world_size // self.node_world_size + + def get_local_rank(self): + return self.rank % self.N_NPU_PER_NODE + + def _run(self, operator, x, tmp_dtype, out_dtype=None): + if self.on_npu: + if out_dtype is None: + out_dtype = x.dtype + x = operator.to_float(tmp_dtype)(x.to(tmp_dtype)) + x = x.to(out_dtype) + return x + else: + return operator(x) + + def run_group_norm(self, operator, x): + return self._run(operator, x, ms.float32) + + def run_layer_norm(self, operator, x): + return self._run(operator, x, ms.float32) + + def run_batch_norm(self, operator, x): + return self._run(operator, x, ms.float32) + + def run_conv3d(self, operator, x, out_dtype): + return self._run(operator, x, self.conv_dtype, out_dtype) + + def run_pool_2d(self, operator, x, kernel_size, stride): + if self.on_npu: + x_dtype = x.dtype + x = x.to(self.replaced_type) + x = operator(x, kernel_size=kernel_size, stride=stride) + x = x.to(x_dtype) + else: + x = operator(x, kernel_size=kernel_size, stride=stride) + return x + + def run_pad_2d(self, operator, x, pad, mode="constant"): + if self.on_npu: + x_dtype = x.dtype + x = x.to(self.replaced_type) + x = operator(x, pad, mode) + x = x.to(x_dtype) + else: + x = operator(x, pad, mode) + return x + + def run_interpolate(self, operator, x, scale_factor=None): + if self.on_npu: + x_dtype = x.dtype + x = x.to(self.replaced_type) + x = operator(x, scale_factor=scale_factor) + x = x.to(x_dtype) + else: + x = operator(x, scale_factor=scale_factor) + return x + + def run_attention(self, query, key, value, attention_mask, input_layout, head_dim, head_num): + if self.enable_FA: + hidden_states = self.ms_flash_attention( + query, + key, + value, + attention_mask=attention_mask, + input_layout=input_layout, + scale=1 / math.sqrt(head_dim), + head_num=head_num, + ) + else: + hidden_states = self.scaled_dot_product_attention( + query, + key, + value, + attention_mask=attention_mask, + input_layout=input_layout, + scale=1 / math.sqrt(head_dim), + head_num=head_num, + ) + return hidden_states + + def ms_flash_attention( + self, + query, + key, + value, + attention_mask, + head_num, + scale, + input_layout="BSH", + attention_dropout: float = 0.0, + ): + # Memory efficient attention on mindspore uses flash attention under the hoods. + # Flash attention implementation is called `FlashAttentionScore` + # which is an experimental api with the following limitations: + # 1. Sequence length of query must be divisible by 16 and in range of [1, 32768]. + # 2. Head dimensions must be one of [64, 80, 96, 120, 128, 256]. + # 3. The input dtype must be float16 or bfloat16. + # Sequence length of query must be checked in runtime. + if input_layout not in ["BSH", "BNSD"]: + raise ValueError(f"input_layout must be in ['BSH', 'BNSD'], but get {input_layout}.") + Bs, query_tokens, inner_dim = query.shape + assert query_tokens % 16 == 0, f"Sequence length of query must be divisible by 16, but got {query_tokens}." + key_tokens = key.shape[1] + heads = head_num + query = query.view(Bs, query_tokens, heads, -1) + key = key.view(Bs, key_tokens, heads, -1) + value = value.view(Bs, key_tokens, heads, -1) + + head_dim = inner_dim // heads + if head_dim in self.flash_attn_valid_head_dims: + head_dim_padding = 0 + else: + minimum_larger_index = bisect.bisect_right(self.flash_attn_valid_head_dims, head_dim) + if minimum_larger_index >= len(self.flash_attn_valid_head_dims): + head_dim_padding = -1 # head_dim is bigger than the largest one, we cannot do padding + else: + head_dim_padding = self.flash_attn_valid_head_dims[minimum_larger_index] - head_dim + # Head dimension is checked in Attention.set_use_memory_efficient_attention_xformers. We maybe pad on head_dim. + if head_dim_padding > 0: + query_padded = mint.nn.functional.pad(query, (0, head_dim_padding), mode="constant", value=0.0) + key_padded = mint.nn.functional.pad(key, (0, head_dim_padding), mode="constant", value=0.0) + value_padded = mint.nn.functional.pad(value, (0, head_dim_padding), mode="constant", value=0.0) + else: + query_padded, key_padded, value_padded = query, key, value + flash_attn = ops.operations.nn_ops.FlashAttentionScore( + scale_value=scale, head_num=heads, input_layout=input_layout, keep_prob=1 - attention_dropout + ) + if attention_mask is not None: + # flip mask, since ms FA treats 1 as discard, 0 as retain. + attention_mask = ~attention_mask if attention_mask.dtype == ms.bool_ else 1 - attention_mask + # (b, 1, 1, k_n) - > (b, 1, q_n, k_n), manual broadcast + if attention_mask.shape[-2] == 1: + attention_mask = mint.tile(attention_mask.bool(), (1, 1, query_tokens, 1)) + attention_mask = attention_mask.to(self.fa_mask_dtype) + + if input_layout == "BNSD": + # (b s n d) -> (b n s d) + query_padded = query_padded.swapaxes(1, 2) + key_padded = key_padded.swapaxes(1, 2) + value_padded = value_padded.swapaxes(1, 2) + elif input_layout == "BSH": + query_padded = query_padded.view(Bs, query_tokens, -1) + key_padded = key_padded.view(Bs, key_tokens, -1) + value_padded = value_padded.view(Bs, key_tokens, -1) + hidden_states_padded = flash_attn( + query_padded.to(self.FA_dtype), + key_padded.to(self.FA_dtype), + value_padded.to(self.FA_dtype), + None, + None, + None, + attention_mask, + )[3] + # If we did padding before calculate attention, undo it! + if head_dim_padding > 0: + if input_layout == "BNSD": + hidden_states = hidden_states_padded[..., :head_dim] + else: + hidden_states = hidden_states_padded.view(Bs, query_tokens, heads, -1)[..., :head_dim] + hidden_states = hidden_states.view(Bs, query_tokens, -1) + else: + hidden_states = hidden_states_padded + if input_layout == "BNSD": + # b n s d -> b s n d + hidden_states = hidden_states.swapaxes(1, 2) + hidden_states = hidden_states.reshape(Bs, query_tokens, -1) + hidden_states = hidden_states.to(query.dtype) + return hidden_states + + def scaled_dot_product_attention( + self, + query, + key, + value, + input_layout, + head_num=None, + attention_mask=None, + scale=None, + dropout_p=0.0, + is_causal=False, + ) -> ms.Tensor: + def trans_tensor_shape(x, layout, head_num): + if layout == "BSH": + batch = x.shape[0] + x = x.view(batch, -1, head_num, x.shape[-1] // head_num).swapaxes(1, 2) + elif layout == "SBH": + batch = x.shape[1] + x = x.view(-1, batch * head_num, x.shape[-1] // head_num).swapaxes(0, 1) + x = x.view(batch, head_num, -1, x.shape[-1]) + return x + + query = trans_tensor_shape(query, input_layout, head_num) + key = trans_tensor_shape(key, input_layout, head_num) + value = trans_tensor_shape(value, input_layout, head_num) + + attn_weight = query @ key.swapaxes(-2, -1) * scale + attn_bias = mint.zeros_like(attn_weight, dtype=query.dtype) + if is_causal: + assert attention_mask is None + temp_mask = mint.zeros_like(attn_weight, dtype=ms.bool).tril(diagonal=0) + attn_bias.masked_fill(~temp_mask, npu_config.inf_float) + attn_bias.to(query.dtype) + + if attention_mask is not None: + assert ( + not self.enable_FA + ) and attention_mask.dtype != ms.bool, "attention_mask must not be bool type when use this function" + + attn_weight += attn_bias + attn_weight = mint.nn.functional.softmax(attn_weight, dim=-1) + attn_weight = mint.nn.functional.dropout(attn_weight, p=dropout_p, training=True) + output = attn_weight @ value + if input_layout == "BSH": + output = output.swapaxes(1, 2).view(output.shape[0], -1, head_num * output.shape[-1]) + else: + output = output.view(output.shape[0] * head_num, -1, output.shape[-1]).swapaxes(0, 1) + output = output.view(output.shape[0], -1, head_num * output.shape[-1]) + return output + + +npu_config = NPUConfig() diff --git a/examples/opensora_pku/opensora/sample/caption_refiner.py b/examples/opensora_pku/opensora/sample/caption_refiner.py new file mode 100644 index 0000000000..a889a90169 --- /dev/null +++ b/examples/opensora_pku/opensora/sample/caption_refiner.py @@ -0,0 +1,36 @@ +from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer + +import mindspore as ms + +TEMPLATE = """ +Refine the sentence: \"{}\" to contain subject description, action, scene description. " \ +"(Optional: camera language, light and shadow, atmosphere) and conceive some additional actions to make the sentence more dynamic. " \ +"Make sure it is a fluent sentence, not nonsense. +""" + + +class OpenSoraCaptionRefiner: + def __init__(self, caption_refiner, dtype): + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained(caption_refiner, ms_dtype=dtype) + self.model = AutoModelForCausalLM.from_pretrained(caption_refiner, ms_dtype=dtype) + + def get_refiner_output(self, prompt): + prompt = TEMPLATE.format(prompt) + messages = [{"role": "system", "content": "You are a caption refiner."}, {"role": "user", "content": prompt}] + input_ids = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + model_inputs = self.tokenizer([input_ids], return_tensors="np") + generated_ids = self.model.generate(ms.Tensor(model_inputs.input_ids), max_new_tokens=512) + generated_ids = [ + output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) + ] + response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + return response + + +if __name__ == "__main__": + pretrained_model_name_or_path = "" + caption_refiner = OpenSoraCaptionRefiner(pretrained_model_name_or_path, dtype=ms.float16) + prompt = "a video of a girl playing in the park" + response = caption_refiner.get_refiner_output(prompt) + print(response) diff --git a/examples/opensora_pku/opensora/sample/rec_image.py b/examples/opensora_pku/opensora/sample/rec_image.py new file mode 100644 index 0000000000..36191269e8 --- /dev/null +++ b/examples/opensora_pku/opensora/sample/rec_image.py @@ -0,0 +1,164 @@ +import argparse +import logging +import os +import sys + +import cv2 +import numpy as np +from albumentations import Compose, Lambda, Resize, ToFloat +from PIL import Image + +import mindspore as ms + +mindone_lib_path = os.path.abspath("../../") +sys.path.insert(0, mindone_lib_path) + +from mindone.utils.config import str2bool +from mindone.utils.logger import set_logger + +sys.path.append(".") + +from opensora.models.causalvideovae import ae_wrapper +from opensora.npu_config import npu_config +from opensora.utils.utils import get_precision + +logger = logging.getLogger(__name__) + + +def create_transform(max_height, max_width): + norm_fun = lambda x: 2.0 * x - 1.0 + + def norm_func_albumentation(image, **kwargs): + return norm_fun(image) + + mapping = {"bilinear": cv2.INTER_LINEAR, "bicubic": cv2.INTER_CUBIC} + resize = [ + Resize(max_height, max_width, interpolation=mapping["bilinear"]), + ] + + transform = Compose( + [*resize, ToFloat(255.0), Lambda(name="ae_norm", image=norm_func_albumentation, p=1.0)], + ) + return transform + + +def preprocess(image, height: int = 128, width: int = 128): + video_transform = create_transform(height, width) + + image = video_transform(image=image)["image"] # (h w c) + # (h w c) -> (c h w) -> (c t h w) + image = np.transpose(image, (2, 0, 1))[:, None, :, :] + return image + + +def transform_to_rgb(x, rescale_to_uint8=True): + x = np.clip(x, -1, 1) + x = (x + 1) / 2 + if rescale_to_uint8: + x = (255 * x).astype(np.uint8) + return x + + +def main(args): + image_path = args.image_path + short_size = args.short_size + npu_config.set_npu_env(args) + + set_logger(name="", output_dir=args.output_path, rank=0) + dtype = get_precision(args.precision) + if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint): + logger.info(f"Run inference with MindSpore checkpoint {args.ms_checkpoint}") + state_dict = ms.load_checkpoint(args.ms_checkpoint) + # rm 'network.' prefix + state_dict = dict( + [k.replace("network.", "") if k.startswith("network.") else k, v] for k, v in state_dict.items() + ) + else: + state_dict = None + kwarg = { + "state_dict": state_dict, + "use_safetensors": True, + "dtype": dtype, + } + vae = ae_wrapper[args.ae](args.ae_path, **kwarg) + + if args.enable_tiling: + vae.vae.enable_tiling() + vae.vae.tile_overlap_factor = args.tile_overlap_factor + + vae.set_train(False) + for param in vae.get_parameters(): + param.requires_grad = False + + input_x = np.array(Image.open(image_path)) # (h w c) + assert input_x.shape[2], f"Expect the input image has three channels, but got shape {input_x.shape}" + x_vae = preprocess(input_x, short_size, short_size) # use image as a single-frame video + + x_vae = ms.Tensor(x_vae, dtype).unsqueeze(0) # b c t h w + latents = vae.encode(x_vae) + latents = latents.to(dtype) + image_recon = vae.decode(latents) # b t c h w + + save_fp = os.path.join(args.output_path, args.rec_path) + x = image_recon[0, 0, :, :, :] + x = x.squeeze().asnumpy() + x = transform_to_rgb(x) + x = x.transpose(1, 2, 0) + if args.grid: + x = np.concatenate([input_x, x], axis=1) + image = Image.fromarray(x) + image.save(save_fp) + if args.grid: + logger.info(f"Save original vs. reconstructed data to {save_fp}") + else: + logger.info(f"Save reconstructed data to {save_fp}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--image_path", type=str, default="") + parser.add_argument("--rec_path", type=str, default="") + parser.add_argument("--ae", type=str, default="WFVAEModel_D8_4x8x8", choices=ae_wrapper.keys()) + parser.add_argument("--ae_path", type=str, default="results/pretrained") + parser.add_argument("--ms_checkpoint", type=str, default=None) + parser.add_argument("--short_size", type=int, default=336) + parser.add_argument("--tile_overlap_factor", type=float, default=0.25) + parser.add_argument("--tile_sample_min_size", type=int, default=256) + parser.add_argument("--enable_tiling", action="store_true") + # ms related + parser.add_argument("--mode", default=1, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") + parser.add_argument( + "--precision", + default="bf16", + type=str, + choices=["fp32", "fp16", "bf16"], + help="mixed precision type, if fp32, all layer precision is float32 (amp_level=O0), \ + if bf16 or fp16, amp_level==O2, part of layers will compute in bf16 or fp16 such as matmul, dense, conv.", + ) + parser.add_argument("--device", type=str, default="Ascend", help="Ascend or GPU") + parser.add_argument( + "--precision_mode", + default=None, + type=str, + help="If specified, set the precision mode for Ascend configurations.", + ) + parser.add_argument( + "--vae_keep_gn_fp32", + default=False, + type=str2bool, + help="whether keep GroupNorm in fp32. Defaults to False in inference, better to set to True when training vae", + ) + parser.add_argument( + "--output_path", default="samples/vae_recons", type=str, help="output directory to save inference results" + ) + parser.add_argument( + "--grid", + action="store_true", + help="whether to use grid to show original and reconstructed data", + ) + parser.add_argument("--jit_level", default="O0", help="Set jit level: # O0: KBK, O1:DVM, O2: GE") + parser.add_argument( + "--jit_syntax_level", default="strict", choices=["strict", "lax"], help="Set jit syntax level: strict or lax" + ) + args = parser.parse_args() + main(args) diff --git a/examples/opensora_pku/opensora/sample/rec_video.py b/examples/opensora_pku/opensora/sample/rec_video.py new file mode 100644 index 0000000000..f5463b4d9c --- /dev/null +++ b/examples/opensora_pku/opensora/sample/rec_video.py @@ -0,0 +1,221 @@ +""" +Run causal vae reconstruction on a given video. +Usage example: +python examples/rec_video.py \ + --ae_path path/to/vae/ckpt \ + --video_path test.mp4 \ + --rec_path rec.mp4 \ + --sample_rate 1 \ + --num_frames 65 \ + --height 480 \ + --width 640 \ +""" +import argparse +import logging +import os +import random +import sys + +import numpy as np +from decord import VideoReader, cpu +from PIL import Image + +import mindspore as ms + +mindone_lib_path = os.path.abspath("../../") +sys.path.insert(0, mindone_lib_path) +from mindone.utils.logger import set_logger +from mindone.visualize.videos import save_videos + +sys.path.append(".") +from functools import partial + +import cv2 +from albumentations import Compose, Lambda, Resize, ToFloat +from opensora.dataset.transform import center_crop_th_tw +from opensora.models.causalvideovae import ae_wrapper +from opensora.npu_config import npu_config +from opensora.utils.utils import get_precision + +logger = logging.getLogger(__name__) + + +def read_video(video_path: str, num_frames: int, sample_rate: int) -> ms.Tensor: + decord_vr = VideoReader(video_path, ctx=cpu(0)) + total_frames = len(decord_vr) + sample_frames_len = sample_rate * num_frames + + if total_frames > sample_frames_len: + s = random.randint(0, total_frames - sample_frames_len - 1) + s = 0 + e = s + sample_frames_len + num_frames = num_frames + else: + s = 0 + e = total_frames + num_frames = int(total_frames / sample_frames_len * num_frames) + print( + f"sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}", + video_path, + total_frames, + ) + + frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int) + video_data = decord_vr.get_batch(frame_id_list).asnumpy() + return video_data + + +def create_transform(max_height, max_width, num_frames): + norm_fun = lambda x: 2.0 * x - 1.0 + + def norm_func_albumentation(image, **kwargs): + return norm_fun(image) + + mapping = {"bilinear": cv2.INTER_LINEAR, "bicubic": cv2.INTER_CUBIC} + targets = {"image{}".format(i): "image" for i in range(num_frames)} + resize = [ + Lambda( + name="crop_centercrop", + image=partial(center_crop_th_tw, th=max_height, tw=max_width, top_crop=False), + p=1.0, + ), + Resize(max_height, max_width, interpolation=mapping["bilinear"]), + ] + + transform = Compose( + [*resize, ToFloat(255.0), Lambda(name="ae_norm", image=norm_func_albumentation, p=1.0)], + additional_targets=targets, + ) + return transform + + +def preprocess(video_data, height: int = 128, width: int = 128): + num_frames = video_data.shape[0] + video_transform = create_transform(height, width, num_frames=num_frames) + + inputs = {"image": video_data[0]} + for i in range(num_frames - 1): + inputs[f"image{i}"] = video_data[i + 1] + + video_outputs = video_transform(**inputs) + video_outputs = np.stack(list(video_outputs.values()), axis=0) # (t h w c) + # (t h w c) -> (c t h w) + video_outputs = np.transpose(video_outputs, (3, 0, 1, 2)) + return video_outputs + + +def transform_to_rgb(x, rescale_to_uint8=True): + x = np.clip(x, -1, 1) + x = (x + 1) / 2 + if rescale_to_uint8: + x = (255 * x).astype(np.uint8) + return x + + +def main(args): + npu_config.set_npu_env(args) + dtype = get_precision(args.precision) + set_logger(name="", output_dir=args.output_path, rank=0) + if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint): + logger.info(f"Run inference with MindSpore checkpoint {args.ms_checkpoint}") + state_dict = ms.load_checkpoint(args.ms_checkpoint) + # rm 'network.' prefix + state_dict = dict( + [k.replace("network.", "") if k.startswith("network.") else k, v] for k, v in state_dict.items() + ) + else: + state_dict = None + kwarg = { + "state_dict": state_dict, + "use_safetensors": True, + "dtype": dtype, + } + vae = ae_wrapper[args.ae](args.ae_path, **kwarg) + + if args.enable_tiling: + vae.vae.enable_tiling() + vae.vae.tile_overlap_factor = args.tile_overlap_factor + + vae.set_train(False) + for param in vae.get_parameters(): + param.requires_grad = False + + x_vae = preprocess(read_video(args.video_path, args.num_frames, args.sample_rate), args.height, args.width) + + x_vae = ms.Tensor(x_vae, dtype).unsqueeze(0) # b c t h w + latents = vae.encode(x_vae) + latents = latents.to(dtype) + video_recon = vae.decode(latents) # b t c h w + + save_fp = os.path.join(args.output_path, args.rec_path) + if ".avi" in os.path.basename(save_fp): + save_fp = save_fp.replace(".avi", ".mp4") + if video_recon.shape[1] == 1: + x = video_recon[0, 0, :, :, :].squeeze().to(ms.float32).asnumpy() + original_rgb = x_vae[0, 0, :, :, :].squeeze().to(ms.float32).asnumpy() + x = transform_to_rgb(x).transpose(1, 2, 0) # c h w -> h w c + original_rgb = transform_to_rgb(original_rgb).transpose(1, 2, 0) # c h w -> h w c + + image = Image.fromarray(np.concatenate([x, original_rgb], axis=1) if args.grid else x) + save_fp = save_fp.replace("mp4", "jpg") + image.save(save_fp) + else: + save_video_data = video_recon.transpose(0, 1, 3, 4, 2).to(ms.float32).asnumpy() # (b t c h w) -> (b t h w c) + save_video_data = transform_to_rgb(save_video_data, rescale_to_uint8=False) + original_rgb = transform_to_rgb(x_vae.to(ms.float32).asnumpy(), rescale_to_uint8=False).transpose( + 0, 2, 3, 4, 1 + ) # (b c t h w) -> (b t h w c) + save_video_data = np.concatenate([original_rgb, save_video_data], axis=3) if args.grid else save_video_data + save_videos(save_video_data, save_fp, loop=0, fps=args.fps) + if args.grid: + logger.info(f"Save original vs. reconstructed data to {save_fp}") + else: + logger.info(f"Save reconstructed data to {save_fp}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--video_path", type=str, default="") + parser.add_argument("--rec_path", type=str, default="") + parser.add_argument("--ae", type=str, default="") + parser.add_argument("--ae_path", type=str, default="results/pretrained") + parser.add_argument("--ms_checkpoint", type=str, default=None) + parser.add_argument("--fps", type=int, default=30) + parser.add_argument("--height", type=int, default=336) + parser.add_argument("--width", type=int, default=336) + parser.add_argument("--num_frames", type=int, default=65) + parser.add_argument("--sample_rate", type=int, default=1) + parser.add_argument("--enable_tiling", action="store_true") + parser.add_argument("--tile_overlap_factor", type=float, default=0.25) + # ms related + parser.add_argument("--mode", default=1, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") + parser.add_argument( + "--precision", + default="bf16", + type=str, + choices=["fp32", "fp16", "bf16"], + help="mixed precision type, if fp32, all layer precision is float32 (amp_level=O0), \ + if bf16 or fp16, amp_level==O2, part of layers will compute in bf16 or fp16 such as matmul, dense, conv.", + ) + + parser.add_argument("--device", type=str, default="Ascend", help="Ascend or GPU") + parser.add_argument( + "--precision_mode", + default=None, + type=str, + help="If specified, set the precision mode for Ascend configurations.", + ) + parser.add_argument( + "--output_path", default="samples/vae_recons", type=str, help="output directory to save inference results" + ) + parser.add_argument( + "--grid", + action="store_true", + help="whether to use grid to show original and reconstructed data", + ) + parser.add_argument("--jit_level", default="O0", help="Set jit level: # O0: KBK, O1:DVM, O2: GE") + parser.add_argument( + "--jit_syntax_level", default="strict", choices=["strict", "lax"], help="Set jit syntax level: strict or lax" + ) + args = parser.parse_args() + main(args) diff --git a/examples/opensora_pku/opensora/train/commons.py b/examples/opensora_pku/opensora/train/commons.py index 3ad629baa7..17650e2051 100644 --- a/examples/opensora_pku/opensora/train/commons.py +++ b/examples/opensora_pku/opensora/train/commons.py @@ -34,7 +34,7 @@ def parse_train_args(parser): ################################################################################# parser.add_argument("--device", type=str, default="Ascend", help="Ascend or GPU") parser.add_argument("--max_device_memory", type=str, default=None, help="e.g. `30GB` for 910a, `59GB` for 910b") - parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") + parser.add_argument("--mode", default=1, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") parser.add_argument( "--jit_syntax_level", default="strict", type=str, help="Specify syntax level for graph mode: strict or lax" ) diff --git a/examples/opensora_pku/opensora/train/train_causalvae.py b/examples/opensora_pku/opensora/train/train_causalvae.py index c7e5a9a2b0..d967baf61c 100644 --- a/examples/opensora_pku/opensora/train/train_causalvae.py +++ b/examples/opensora_pku/opensora/train/train_causalvae.py @@ -1,7 +1,6 @@ """ Train AutoEncoders with GAN loss """ -import json import logging import math import os @@ -12,19 +11,19 @@ import yaml import mindspore as ms -from mindspore import Model, nn +from mindspore import Model from mindspore.train.callback import TimeMonitor sys.path.append(".") mindone_lib_path = os.path.abspath("../../") sys.path.insert(0, mindone_lib_path) -from opensora.models.causalvideovae.model import EMA, CausalVAEModel from opensora.models.causalvideovae.model.dataset_videobase import VideoDataset, create_dataloader +from opensora.models.causalvideovae.model.ema_model import EMA from opensora.models.causalvideovae.model.losses.net_with_loss import DiscriminatorWithLoss, GeneratorWithLoss -from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate +from opensora.models.causalvideovae.model.registry import ModelRegistry from opensora.models.causalvideovae.model.utils.model_utils import resolve_str_to_obj +from opensora.npu_config import npu_config from opensora.train.commons import create_loss_scaler, parse_args -from opensora.utils.ms_utils import init_env from opensora.utils.utils import get_precision from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallback @@ -32,7 +31,6 @@ from mindone.trainers.lr_schedule import create_scheduler from mindone.trainers.optim import create_optimizer from mindone.trainers.train_step import TrainOneStepWrapper -from mindone.utils.amp import auto_mixed_precision from mindone.utils.config import str2bool from mindone.utils.logger import set_logger from mindone.utils.params import count_params @@ -42,24 +40,33 @@ def main(args): # 1. init - rank_id, device_num = init_env( - args.mode, - seed=args.seed, - distributed=args.use_parallel, - device_target=args.device, - max_device_memory=args.max_device_memory, - parallel_mode=args.parallel_mode, - jit_level=args.jit_level, - jit_syntax_level=args.jit_syntax_level, - ) + rank_id, device_num = npu_config.set_npu_env(args) + dtype = get_precision(args.precision) + if args.exp_name is not None and len(args.exp_name) > 0: args.output_dir = os.path.join(args.output_dir, args.exp_name) set_logger(name="", output_dir=args.output_dir, rank=rank_id, log_level=eval(args.log_level)) # Load Config - assert os.path.exists(args.model_config), f"{args.model_config} does not exist!" - model_config = json.load(open(args.model_config, "r")) - ae = CausalVAEModel.from_config(model_config, use_recompute=args.use_recompute) + model_cls = ModelRegistry.get_model(args.model_name) + + if not model_cls: + raise ModuleNotFoundError(f"`{args.model_name}` not in {str(ModelRegistry._models.keys())}.") + if args.pretrained_model_name_or_path is not None: + if rank_id == 0: + logger.warning(f"You are loading a checkpoint from `{args.pretrained_model_name_or_path}`.") + ae = model_cls.from_pretrained( + args.pretrained_model_name_or_path, + ignore_mismatched_sizes=args.ignore_mismatched_sizes, + low_cpu_mem_usage=False, + device_map=None, + dtype=dtype, + ) + else: + if rank_id == 0: + logger.warning(f"Model will be initialized from config file {args.model_config}.") + ae = model_cls.from_config(args.model_config, dtype=dtype) + if args.load_from_checkpoint is not None: ae.init_from_ckpt(args.load_from_checkpoint) # discriminator (D) @@ -76,34 +83,14 @@ def main(args): elif "LPIPSWithDiscriminator" in disc_type: disc_type = "opensora.models.causalvideovae.model.losses.discriminator.NLayerDiscriminator" use_3d_disc = False - disc = resolve_str_to_obj(disc_type, append=False)() + disc = resolve_str_to_obj(disc_type, append=False)(dtype=dtype) else: disc = None - # mixed precision - # TODO: set softmax, sigmoid computed in FP32. manually set inside network since they are ops, instead of layers whose precision will be set by AMP level. - if args.precision in ["fp16", "bf16"]: - amp_level = args.amp_level - dtype = get_precision(args.precision) - if dtype == ms.float16: - custom_fp32_cells = [nn.GroupNorm, nn.Softmax, nn.SiLU] if args.vae_keep_gn_fp32 else [nn.Softmax, nn.SiLU] - else: - custom_fp32_cells = [nn.AvgPool2d, TrilinearInterpolate, nn.Softmax, nn.SiLU] - ae = auto_mixed_precision(ae, amp_level=amp_level, dtype=dtype, custom_fp32_cells=custom_fp32_cells) - logger.info( - f"Use amp level {amp_level} for causal 3D VAE with dtype={dtype}, custom_fp32_cells {custom_fp32_cells}" - ) - - if use_discriminator: - disc = auto_mixed_precision(disc, amp_level, dtype) - logger.info(f"Use amp level {amp_level} for discriminator with dtype={dtype}") - elif args.precision == "fp32": - amp_level = "O0" - else: - raise ValueError(f"Unsupported precision {args.precision}") - # 3. build net with loss (core) # G with loss + if args.wavelet_loss: + logger.warning("wavelet_loss is not implemented, and will be ignored.") ae_with_loss = GeneratorWithLoss( ae, discriminator=disc, @@ -114,6 +101,7 @@ def main(args): logvar_init=args.logvar_init, perceptual_weight=args.perceptual_weight, loss_type=args.loss_type, + wavelet_weight=args.wavelet_weight, ) disc_start = args.disc_start @@ -343,12 +331,11 @@ def main(args): key_info = "Key Settings:\n" + "=" * 50 + "\n" key_info += "\n".join( [ - f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}", - f"Jit level: {args.jit_level}", + f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}" + + (f"\nJit level: {args.jit_level}" if args.mode == 0 else ""), f"Distributed mode: {args.use_parallel}", - f"Recompute: {args.use_recompute}", - f"amp level: {amp_level}", f"dtype: {args.precision}", + f"Optimizer: {args.optim}", f"Use discriminator: {args.use_discriminator}", f"Learning rate: {learning_rate}", f"Batch size: {args.train_batch_size}", @@ -393,7 +380,7 @@ def main(args): ckpt_save_interval=ckpt_save_interval, log_interval=args.log_interval, start_epoch=start_epoch, - model_name="vae_3d", + model_name=args.model_name, record_lr=False, save_training_resume=args.save_training_resume, ) @@ -565,6 +552,12 @@ def parse_causalvae_train_args(parser): default="scripts/causalvae/release.json", help="the default model configuration file for the causalvae.", ) + parser.add_argument( + "--model_name", + default="", + help="the default model name for the causalvae.", + ) + parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, help="") parser.add_argument( "--vae_keep_gn_fp32", default=True, @@ -625,6 +618,8 @@ def parse_causalvae_train_args(parser): parser.add_argument("--perceptual_weight", type=float, default=1.0, help="") parser.add_argument("--loss_type", type=str, default="l1", help="") parser.add_argument("--logvar_init", type=float, default=0.0, help="") + parser.add_argument("--wavelet_loss", action="store_true", help="") + parser.add_argument("--wavelet_weight", type=float, default=0.1, help="") return parser diff --git a/examples/opensora_pku/requirements.txt b/examples/opensora_pku/requirements.txt index 95d1af1f22..ca328d81ac 100644 --- a/examples/opensora_pku/requirements.txt +++ b/examples/opensora_pku/requirements.txt @@ -14,9 +14,8 @@ safetensors omegaconf pyyaml sentencepiece -av -beautifulsoup4 -huggingface_hub>=0.22.2,<0.26 -transformers -tokenizers -pillow +mindnlp==0.4.0 +transformers>=4.46.0 +pyav +bs4 +huggingface_hub>=0.22.2 diff --git a/examples/opensora_pku/scripts/causalvae/rec_image.sh b/examples/opensora_pku/scripts/causalvae/rec_image.sh index 4f207e4c78..38fc720683 100644 --- a/examples/opensora_pku/scripts/causalvae/rec_image.sh +++ b/examples/opensora_pku/scripts/causalvae/rec_image.sh @@ -1,7 +1,8 @@ python examples/rec_image.py \ - --ae_path LanguageBind/Open-Sora-Plan-v1.2.0/vae \ - --image_path /storage/dataset/image/anytext3m/ocr_data/Art/images/gt_5544.jpg \ + --ae "WFVAEModel_D8_4x8x8" \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ + --image_path example.jpg \ --rec_path rec.jpg \ --device Ascend \ --short_size 512 \ - --enable_tiling + --mode 1 \ diff --git a/examples/opensora_pku/scripts/causalvae/rec_video.sh b/examples/opensora_pku/scripts/causalvae/rec_video.sh index 9cfeff45fd..4a9716c28b 100644 --- a/examples/opensora_pku/scripts/causalvae/rec_video.sh +++ b/examples/opensora_pku/scripts/causalvae/rec_video.sh @@ -1,12 +1,13 @@ python examples/rec_video.py \ - --ae_path LanguageBind/Open-Sora-Plan-v1.2.0/vae \ + --ae "WFVAEModel_D8_4x8x8" \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ --video_path test.mp4 \ --rec_path rec.mp4 \ --device Ascend \ --sample_rate 1 \ - --num_frames 65 \ - --height 480 \ - --width 640 \ + --num_frames 61 \ + --height 512 \ + --width 512 \ + --fps 30 \ --enable_tiling \ - --tile_overlap_factor 0.125 \ - --save_memory + --mode 1 \ diff --git a/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh b/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh index 5c39cbc15f..a52d7c3332 100644 --- a/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh +++ b/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh @@ -9,7 +9,7 @@ python examples/rec_video_folder.py \ --height 480 \ --width 640 \ --num_workers 8 \ - --ae_path LanguageBind/Open-Sora-Plan-v1.2.0/vae \ + --ae "WFVAEModel_D8_4x8x8" \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ --enable_tiling \ --tile_overlap_factor 0.125 \ - --save_memory diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh index 88a9997ee6..764e12b1f3 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh @@ -1,5 +1,7 @@ python opensora/train/train_causalvae.py \ --exp_name "25x256x256" \ + --model_name WFVAE \ + --model_config scripts/causalvae/wfvae_4dim.json \ --train_batch_size 1 \ --precision fp32 \ --max_steps 100000 \ @@ -9,24 +11,24 @@ python opensora/train/train_causalvae.py \ --data_file_path datasets/ucf101_train.csv \ --video_num_frames 25 \ --resolution 256 \ - --sample_rate 2 \ --dataloader_num_workers 8 \ --load_from_checkpoint pretrained/causal_vae_488_init.ckpt \ --start_learning_rate 1e-5 \ --lr_scheduler constant \ - --optim adam \ + --optim adamw \ --betas 0.9 0.999 \ --clip_grad True \ --weight_decay 0.0 \ - --mode 0 \ + --mode 1 \ --init_loss_scale 65536 \ --jit_level "O0" \ --use_discriminator True \ --use_ema True\ - --ema_start_step 0 \ --ema_decay 0.999 \ --perceptual_weight 1.0 \ --loss_type l1 \ + --sample_rate 1 \ --disc_cls causalvideovae.model.losses.LPIPSWithDiscriminator3D \ - --disc_start 2000 \ - --use_recompute True \ + --disc_start 0 \ + --wavelet_loss \ + --wavelet_weight 0.1 diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh index 918131ef1c..9adfcba876 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh @@ -21,11 +21,11 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --load_from_checkpoint pretrained/causal_vae_488_init.ckpt \ --start_learning_rate 1e-5 \ --lr_scheduler constant \ - --optim adam \ + --optim adamw \ --betas 0.9 0.999 \ --clip_grad True \ --weight_decay 0.0 \ - --mode 0 \ + --mode 1 \ --init_loss_scale 65536 \ --jit_level "O0" \ --use_discriminator True \ diff --git a/examples/opensora_pku/scripts/causalvae/wfvae_4dim.json b/examples/opensora_pku/scripts/causalvae/wfvae_4dim.json new file mode 100644 index 0000000000..6509a76d4f --- /dev/null +++ b/examples/opensora_pku/scripts/causalvae/wfvae_4dim.json @@ -0,0 +1,23 @@ +{ + "_class_name": "WFVAEModel", + "_diffusers_version": "0.30.2", + "base_channels": 128, + "connect_res_layer_num": 1, + "decoder_energy_flow_hidden_size": 128, + "decoder_num_resblocks": 2, + "dropout": 0.0, + "encoder_energy_flow_hidden_size": 128, + "encoder_num_resblocks": 2, + "l1_dowmsample_block": "Downsample", + "l1_downsample_wavelet": "HaarWaveletTransform2D", + "l1_upsample_block": "Upsample", + "l1_upsample_wavelet": "InverseHaarWaveletTransform2D", + "l2_dowmsample_block": "Spatial2xTime2x3DDownsample", + "l2_downsample_wavelet": "HaarWaveletTransform3D", + "l2_upsample_block": "Spatial2xTime2x3DUpsample", + "l2_upsample_wavelet": "InverseHaarWaveletTransform3D", + "latent_dim": 4, + "norm_type": "layernorm", + "t_interpolation": "trilinear", + "use_attention": true +} diff --git a/examples/opensora_pku/tests/test_wavelet.py b/examples/opensora_pku/tests/test_wavelet.py new file mode 100644 index 0000000000..2afa5fe75c --- /dev/null +++ b/examples/opensora_pku/tests/test_wavelet.py @@ -0,0 +1,71 @@ +import os +import sys +import unittest + +import numpy as np +import torch + +from mindspore import Tensor + +sys.path.insert(0, os.path.abspath("./")) +sys.path.insert(0, os.path.abspath("../../")) +from opensora.models.causalvideovae.model.modules.wavelet import ( + HaarWaveletTransform2D, + HaarWaveletTransform3D, + InverseHaarWaveletTransform2D, + InverseHaarWaveletTransform3D, +) +from tests.torch_wavelet import HaarWaveletTransform2D as HaarWaveletTransform2D_torch +from tests.torch_wavelet import HaarWaveletTransform3D as HaarWaveletTransform3D_torch +from tests.torch_wavelet import InverseHaarWaveletTransform2D as InverseHaarWaveletTransform2D_torch +from tests.torch_wavelet import InverseHaarWaveletTransform3D as InverseHaarWaveletTransform3D_torch + +sys.path.append(".") +import mindspore as ms + +dtype = ms.float16 + + +class TestWaveletTransforms(unittest.TestCase): + def setUp(self): + # Initialize all modules + self.modules = { + "HaarWaveletTransform2D": [HaarWaveletTransform2D(), HaarWaveletTransform2D_torch()], + "HaarWaveletTransform3D": [HaarWaveletTransform3D(), HaarWaveletTransform3D_torch()], + "InverseHaarWaveletTransform3D": [InverseHaarWaveletTransform3D(), InverseHaarWaveletTransform3D_torch()], + "InverseHaarWaveletTransform2D": [InverseHaarWaveletTransform2D(), InverseHaarWaveletTransform2D_torch()], + } + + def generate_input(self, module_name): + # Define input shapes based on module name + input_shapes = { + "HaarWaveletTransform2D": (1, 1, 6, 6), # Example shape for 2D + "HaarWaveletTransform3D": (1, 1, 6, 6, 6), # Example shape for 3D + "InverseHaarWaveletTransform3D": (1, 8, 6, 6, 6), # Example shape for 3D + "InverseHaarWaveletTransform2D": (1, 1, 6, 6), # Example shape for 2D + } + shape = input_shapes[module_name] + return torch.randn(*shape) + + def test_output_similarity(self): + for module_name, module in self.modules.items(): + with self.subTest(module=module_name): + x_torch = self.generate_input(module_name) + x_mindspore = Tensor(x_torch.numpy()) + module_ms, module_torch = module + + output_torch = module_torch(x_torch) + output_mindspore = module_ms(x_mindspore.to(dtype)) + + abs_diff = np.abs(output_torch.numpy() - output_mindspore.asnumpy()) + print(f"Mean Absolute Difference for {module_name}: {abs_diff.mean()}") + print(f"Relative Abs Difference for {module_name}: {np.mean(abs_diff/(output_torch.numpy()+1e-6))}") + + self.assertTrue( + np.allclose(output_torch.numpy(), output_mindspore.asnumpy(), atol=1e-5), + f"Outputs of {module_name} are not close enough.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/opensora_pku/tests/torch_wavelet.py b/examples/opensora_pku/tests/torch_wavelet.py new file mode 100644 index 0000000000..cc912e317a --- /dev/null +++ b/examples/opensora_pku/tests/torch_wavelet.py @@ -0,0 +1,327 @@ +torch_npu = None +npu_config = None + +from typing import Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) or isinstance(t, list) else ((t,) * length) + + +class CausalConv3d(nn.Module): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int, int, int]], + enable_cached=False, + bias=True, + **kwargs, + ): + super().__init__() + self.kernel_size = cast_tuple(kernel_size, 3) + self.time_kernel_size = self.kernel_size[0] + self.chan_in = chan_in + self.chan_out = chan_out + self.stride = kwargs.pop("stride", 1) + self.padding = kwargs.pop("padding", 0) + self.padding = list(cast_tuple(self.padding, 3)) + self.padding[0] = 0 + self.stride = cast_tuple(self.stride, 3) + self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=self.stride, padding=self.padding, bias=bias) + self.enable_cached = enable_cached + self.causal_cached = None + self.cache_offset = 0 + + def forward(self, x): + x_dtype = x.dtype + if self.causal_cached is None: + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) + else: + first_frame_pad = self.causal_cached + x = torch.concatenate((first_frame_pad, x), dim=2) + + if self.enable_cached and self.time_kernel_size != 1: + if (self.time_kernel_size - 1) // self.stride[0] != 0: + if self.cache_offset == 0: + self.causal_cached = x[:, :, -(self.time_kernel_size - 1) // self.stride[0] :] + else: + self.causal_cached = x[:, :, : -self.cache_offset][ + :, :, -(self.time_kernel_size - 1) // self.stride[0] : + ] + else: + self.causal_cached = x[:, :, 0:0, :, :] + + if npu_config is not None and npu_config.on_npu: + return npu_config.run_conv3d(self.conv, x, x_dtype) + else: + x = self.conv(x) + return x + + +def video_to_image(func): + def wrapper(self, x, *args, **kwargs): + if x.dim() == 5: + t = x.shape[2] + if True: + x = rearrange(x, "b c t h w -> (b t) c h w") + x = func(self, x, *args, **kwargs) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + else: + # Conv 2d slice infer + result = [] + for i in range(t): + frame = x[:, :, i, :, :] + frame = func(self, frame, *args, **kwargs) + result.append(frame.unsqueeze(2)) + x = torch.concatenate(result, dim=2) + return x + + return wrapper + + +class HaarWaveletTransform3D(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + h = torch.tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]) * 0.3536 + g = torch.tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]]) * 0.3536 + hh = torch.tensor([[[1, 1], [-1, -1]], [[1, 1], [-1, -1]]]) * 0.3536 + gh = torch.tensor([[[1, -1], [-1, 1]], [[1, -1], [-1, 1]]]) * 0.3536 + h_v = torch.tensor([[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]]) * 0.3536 + g_v = torch.tensor([[[1, -1], [1, -1]], [[-1, 1], [-1, 1]]]) * 0.3536 + hh_v = torch.tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]]) * 0.3536 + gh_v = torch.tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]]) * 0.3536 + h = h.view(1, 1, 2, 2, 2) + g = g.view(1, 1, 2, 2, 2) + hh = hh.view(1, 1, 2, 2, 2) + gh = gh.view(1, 1, 2, 2, 2) + h_v = h_v.view(1, 1, 2, 2, 2) + g_v = g_v.view(1, 1, 2, 2, 2) + hh_v = hh_v.view(1, 1, 2, 2, 2) + gh_v = gh_v.view(1, 1, 2, 2, 2) + + self.h_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) + self.g_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) + self.hh_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) + self.gh_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) + self.h_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) + self.g_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) + self.hh_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) + self.gh_v_conv = CausalConv3d(1, 1, 2, padding=0, stride=2, bias=False) + + self.h_conv.conv.weight.data = h + self.g_conv.conv.weight.data = g + self.hh_conv.conv.weight.data = hh + self.gh_conv.conv.weight.data = gh + self.h_v_conv.conv.weight.data = h_v + self.g_v_conv.conv.weight.data = g_v + self.hh_v_conv.conv.weight.data = hh_v + self.gh_v_conv.conv.weight.data = gh_v + self.h_conv.requires_grad_(False) + self.g_conv.requires_grad_(False) + self.hh_conv.requires_grad_(False) + self.gh_conv.requires_grad_(False) + self.h_v_conv.requires_grad_(False) + self.g_v_conv.requires_grad_(False) + self.hh_v_conv.requires_grad_(False) + self.gh_v_conv.requires_grad_(False) + + def forward(self, x): + assert x.dim() == 5 + + if torch_npu is not None: + dtype = x.dtype + x = x.to(npu_config.conv_dtype) + self.to(npu_config.conv_dtype) + + b = x.shape[0] + x = rearrange(x, "b c t h w -> (b c) 1 t h w") + low_low_low = self.h_conv(x) + low_low_low = rearrange(low_low_low, "(b c) 1 t h w -> b c t h w", b=b) + low_low_high = self.g_conv(x) + low_low_high = rearrange(low_low_high, "(b c) 1 t h w -> b c t h w", b=b) + low_high_low = self.hh_conv(x) + low_high_low = rearrange(low_high_low, "(b c) 1 t h w -> b c t h w", b=b) + low_high_high = self.gh_conv(x) + low_high_high = rearrange(low_high_high, "(b c) 1 t h w -> b c t h w", b=b) + high_low_low = self.h_v_conv(x) + high_low_low = rearrange(high_low_low, "(b c) 1 t h w -> b c t h w", b=b) + high_low_high = self.g_v_conv(x) + high_low_high = rearrange(high_low_high, "(b c) 1 t h w -> b c t h w", b=b) + high_high_low = self.hh_v_conv(x) + high_high_low = rearrange(high_high_low, "(b c) 1 t h w -> b c t h w", b=b) + high_high_high = self.gh_v_conv(x) + high_high_high = rearrange(high_high_high, "(b c) 1 t h w -> b c t h w", b=b) + + output = torch.cat( + [ + low_low_low, + low_low_high, + low_high_low, + low_high_high, + high_low_low, + high_low_high, + high_high_low, + high_high_high, + ], + dim=1, + ) + + if torch_npu is not None: + x = x.to(dtype) + output = output.to(dtype) + self.to(dtype) + + return output + + +class InverseHaarWaveletTransform3D(nn.Module): + def __init__(self, enable_cached=False, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.register_buffer("h", torch.tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536) + self.register_buffer("g", torch.tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536) + self.register_buffer("hh", torch.tensor([[[1, 1], [-1, -1]], [[1, 1], [-1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536) + self.register_buffer("gh", torch.tensor([[[1, -1], [-1, 1]], [[1, -1], [-1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536) + self.register_buffer("h_v", torch.tensor([[[1, 1], [1, 1]], [[-1, -1], [-1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536) + self.register_buffer("g_v", torch.tensor([[[1, -1], [1, -1]], [[-1, 1], [-1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536) + self.register_buffer( + "hh_v", torch.tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536 + ) + self.register_buffer( + "gh_v", torch.tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]]).view(1, 1, 2, 2, 2) * 0.3536 + ) + self.enable_cached = enable_cached + self.causal_cached = None + + def forward(self, coeffs): + assert coeffs.dim() == 5 + + if torch_npu is not None: + dtype = coeffs.dtype + coeffs = coeffs.to(npu_config.conv_dtype) + self.h = self.h.to(npu_config.conv_dtype) + self.g = self.g.to(npu_config.conv_dtype) + self.hh = self.hh.to(npu_config.conv_dtype) + self.gh = self.gh.to(npu_config.conv_dtype) + self.h_v = self.h_v.to(npu_config.conv_dtype) + self.g_v = self.g_v.to(npu_config.conv_dtype) + self.hh_v = self.hh_v.to(npu_config.conv_dtype) + self.gh_v = self.gh_v.to(npu_config.conv_dtype) + + b = coeffs.shape[0] + + ( + low_low_low, + low_low_high, + low_high_low, + low_high_high, + high_low_low, + high_low_high, + high_high_low, + high_high_high, + ) = coeffs.chunk(8, dim=1) + + low_low_low = rearrange(low_low_low, "b c t h w -> (b c) 1 t h w") + low_low_high = rearrange(low_low_high, "b c t h w -> (b c) 1 t h w") + low_high_low = rearrange(low_high_low, "b c t h w -> (b c) 1 t h w") + low_high_high = rearrange(low_high_high, "b c t h w -> (b c) 1 t h w") + high_low_low = rearrange(high_low_low, "b c t h w -> (b c) 1 t h w") + high_low_high = rearrange(high_low_high, "b c t h w -> (b c) 1 t h w") + high_high_low = rearrange(high_high_low, "b c t h w -> (b c) 1 t h w") + high_high_high = rearrange(high_high_high, "b c t h w -> (b c) 1 t h w") + + low_low_low = F.conv_transpose3d(low_low_low, self.h, stride=2) + low_low_high = F.conv_transpose3d(low_low_high, self.g, stride=2) + low_high_low = F.conv_transpose3d(low_high_low, self.hh, stride=2) + low_high_high = F.conv_transpose3d(low_high_high, self.gh, stride=2) + high_low_low = F.conv_transpose3d(high_low_low, self.h_v, stride=2) + high_low_high = F.conv_transpose3d(high_low_high, self.g_v, stride=2) + high_high_low = F.conv_transpose3d(high_high_low, self.hh_v, stride=2) + high_high_high = F.conv_transpose3d(high_high_high, self.gh_v, stride=2) + if self.enable_cached and self.causal_cached: + reconstructed = ( + low_low_low + + low_low_high + + low_high_low + + low_high_high + + high_low_low + + high_low_high + + high_high_low + + high_high_high + ) + else: + reconstructed = ( + low_low_low[:, :, 1:] + + low_low_high[:, :, 1:] + + low_high_low[:, :, 1:] + + low_high_high[:, :, 1:] + + high_low_low[:, :, 1:] + + high_low_high[:, :, 1:] + + high_high_low[:, :, 1:] + + high_high_high[:, :, 1:] + ) + self.causal_cached = True + reconstructed = rearrange(reconstructed, "(b c) 1 t h w -> b c t h w", b=b) + + if torch_npu is not None: + coeffs = coeffs.to(dtype) + reconstructed = reconstructed.to(dtype) + self.h = self.h.to(dtype) + self.g = self.g.to(dtype) + self.hh = self.hh.to(dtype) + self.gh = self.gh.to(dtype) + self.h_v = self.h_v.to(dtype) + self.g_v = self.g_v.to(dtype) + self.hh_v = self.hh_v.to(dtype) + self.gh_v = self.gh_v.to(dtype) + + return reconstructed + + +class HaarWaveletTransform2D(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("aa", torch.tensor([[1, 1], [1, 1]]).view(1, 1, 2, 2) / 2) + self.register_buffer("ad", torch.tensor([[1, 1], [-1, -1]]).view(1, 1, 2, 2) / 2) + self.register_buffer("da", torch.tensor([[1, -1], [1, -1]]).view(1, 1, 2, 2) / 2) + self.register_buffer("dd", torch.tensor([[1, -1], [-1, 1]]).view(1, 1, 2, 2) / 2) + + @video_to_image + def forward(self, x): + b, c, h, w = x.shape + x = x.reshape(b * c, 1, h, w) + low_low = F.conv2d(x, self.aa, stride=2).reshape(b, c, h // 2, w // 2) + low_high = F.conv2d(x, self.ad, stride=2).reshape(b, c, h // 2, w // 2) + high_low = F.conv2d(x, self.da, stride=2).reshape(b, c, h // 2, w // 2) + high_high = F.conv2d(x, self.dd, stride=2).reshape(b, c, h // 2, w // 2) + coeffs = torch.cat([low_low, low_high, high_low, high_high], dim=1) + return coeffs + + +class InverseHaarWaveletTransform2D(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("aa", torch.tensor([[1, 1], [1, 1]]).view(1, 1, 2, 2) / 2) + self.register_buffer("ad", torch.tensor([[1, 1], [-1, -1]]).view(1, 1, 2, 2) / 2) + self.register_buffer("da", torch.tensor([[1, -1], [1, -1]]).view(1, 1, 2, 2) / 2) + self.register_buffer("dd", torch.tensor([[1, -1], [-1, 1]]).view(1, 1, 2, 2) / 2) + + @video_to_image + def forward(self, coeffs): + low_low, low_high, high_low, high_high = coeffs.chunk(4, dim=1) + b, c, height_half, width_half = low_low.shape + height = height_half * 2 + width = width_half * 2 + + low_low = F.conv_transpose2d(low_low.reshape(b * c, 1, height_half, width_half), self.aa, stride=2) + low_high = F.conv_transpose2d(low_high.reshape(b * c, 1, height_half, width_half), self.ad, stride=2) + high_low = F.conv_transpose2d(high_low.reshape(b * c, 1, height_half, width_half), self.da, stride=2) + high_high = F.conv_transpose2d(high_high.reshape(b * c, 1, height_half, width_half), self.dd, stride=2) + + return (low_low + low_high + high_low + high_high).reshape(b, c, height, width) diff --git a/examples/opensora_pku/tools/model_conversion/convert_pytorch_ckpt_to_safetensors.py b/examples/opensora_pku/tools/model_conversion/convert_pytorch_ckpt_to_safetensors.py index 69f986d68e..b28d150bab 100644 --- a/examples/opensora_pku/tools/model_conversion/convert_pytorch_ckpt_to_safetensors.py +++ b/examples/opensora_pku/tools/model_conversion/convert_pytorch_ckpt_to_safetensors.py @@ -31,9 +31,9 @@ def _remove_duplicate_names( complete_names = {name} else: raise RuntimeError( - f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}.\ + f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst {shared}.\ None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. \ - Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue." + Please refer to https: //huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue." ) keep_name = sorted(list(complete_names))[0] @@ -62,9 +62,7 @@ def check_file_size(sf_filename: str, pt_filename: str): if (sf_size - pt_size) / pt_size > 0.01: raise RuntimeError( - f"""The file size different is more than 1%: - - {sf_filename}: {sf_size} - - {pt_filename}: {pt_size} + f"""The file size different is more than 1%, \n - {sf_filename} {sf_size} \n - {pt_filename} {pt_size} """ ) @@ -97,6 +95,8 @@ def convert_file( loaded = torch.load(pt_filename, map_location="cpu", weights_only=True) if "state_dict" in loaded: loaded = loaded["state_dict"] + if "ema_state_dict" in loaded: + loaded = loaded["ema_state_dict"] to_removes = _remove_duplicate_names(loaded, discard_names=discard_names) metadata = {"format": "pt"} @@ -153,4 +153,4 @@ def convert_file( discard_names = get_discard_names(config_path) if config_path else [] convert_file(pt_filename, sf_filename, discard_names) - print(f"Conversion successful! `safetensors` file saved at: {sf_filename}") + print(f"Conversion successful! safetensors file saved at: {sf_filename}") diff --git a/examples/opensora_pku/tools/model_conversion/convert_wfvae.py b/examples/opensora_pku/tools/model_conversion/convert_wfvae.py new file mode 100644 index 0000000000..fd13222e10 --- /dev/null +++ b/examples/opensora_pku/tools/model_conversion/convert_wfvae.py @@ -0,0 +1,79 @@ +import argparse +import os + +import torch +from safetensors.torch import load_file, save_file + + +def load_torch_ckpt(ckpt_file): + # copied from modeling_wfvae.py init_from_ckpt + def init_from_ckpt(path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu") + print("init from " + path) + if "ema_state_dict" in sd and len(sd["ema_state_dict"]) > 0 and os.environ.get("NOT_USE_EMA_MODEL", 0) == 0: + print("Load from ema model!") + sd = sd["ema_state_dict"] + sd = {key.replace("module.", ""): value for key, value in sd.items()} + elif "state_dict" in sd: + print("Load from normal model!") + if "gen_model" in sd["state_dict"]: + sd = sd["state_dict"]["gen_model"] + else: + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + return sd + + return init_from_ckpt(ckpt_file) + + +def check_file_size(sf_filename: str, pt_filename: str): + sf_size = os.stat(sf_filename).st_size + pt_size = os.stat(pt_filename).st_size + + if (sf_size - pt_size) / pt_size > 0.01: + raise RuntimeError( + f"""The file size different is more than 1%, \n - {sf_filename} {sf_size} \n - {pt_filename} {pt_size} + """ + ) + + +def convert_file( + pt_filename: str, + sf_filename: str, +): + loaded = load_torch_ckpt(pt_filename) + + # to_removes = _remove_duplicate_names(loaded, discard_names=discard_names) + metadata = {"format": "pt"} + # Force tensors to be contiguous + loaded = {k: v.contiguous() for k, v in loaded.items()} + + dirname = os.path.dirname(sf_filename) + os.makedirs(dirname, exist_ok=True) + save_file(loaded, sf_filename, metadata=metadata) + check_file_size(sf_filename, pt_filename) + reloaded = load_file(sf_filename) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--src", type=str, default=None, help="path to torch checkpoint path, e.g., merged.ckpt") + parser.add_argument( + "--target", + type=str, + help="The path to save the converted `safetensors` file (e.g., model.safetensors).", + ) + args = parser.parse_args() + + convert_file(args.src, args.target) + print(f"converted checkpoint saved to {args.target}") diff --git a/examples/opensora_pku/tools/model_conversion/inflate_vae2d_to_vae3d.py b/examples/opensora_pku/tools/model_conversion/inflate_vae2d_to_vae3d.py index a3d02f01d2..2eba9c52c8 100644 --- a/examples/opensora_pku/tools/model_conversion/inflate_vae2d_to_vae3d.py +++ b/examples/opensora_pku/tools/model_conversion/inflate_vae2d_to_vae3d.py @@ -11,6 +11,7 @@ from opensora.utils.ms_utils import init_env import mindspore as ms +from mindspore import mint def inflate(args): @@ -64,7 +65,7 @@ def inflate(args): print(key_2d, shape_3d, shape_2d) if len(shape_3d) > len(shape_2d): w = vae2d_sd[key_2d] - new_w = ms.ops.zeros(shape_3d, dtype=w.dtype) + new_w = mint.zeros(shape_3d, dtype=w.dtype) # tail initialization new_w[:, :, -1, :, :] = w # cin, cout, t, h, w From bf364a6cbf25be649d7d8a895336a1b3d4d775da Mon Sep 17 00:00:00 2001 From: chenyingshu Date: Tue, 29 Oct 2024 18:15:25 +0800 Subject: [PATCH 002/133] add dit, t2v inference --- examples/opensora_pku/README.md | 41 +- .../opensora/models/diffusion/__init__.py | 6 +- .../diffusion/opensora/modeling_opensora.py | 974 +++-------- .../models/diffusion/opensora/modules.py | 1538 ++++------------- .../opensora/sample/pipeline_opensora.py | 882 +++++----- .../sample/{sample_t2v.py => sample.py} | 653 ++++--- .../single-device/sample_t2v_29x1280.sh | 26 + .../single-device/sample_t2v_29x720p.sh | 19 - .../sample_t2v_93x640_1texenc.sh | 25 + .../sample_t2v_93x640_2texenc.sh | 23 + 10 files changed, 1554 insertions(+), 2633 deletions(-) rename examples/opensora_pku/opensora/sample/{sample_t2v.py => sample.py} (51%) create mode 100644 examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x1280.sh delete mode 100644 examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x720p.sh create mode 100644 examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_1texenc.sh create mode 100644 examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_2texenc.sh diff --git a/examples/opensora_pku/README.md b/examples/opensora_pku/README.md index 670437fb64..b72bbad8e5 100644 --- a/examples/opensora_pku/README.md +++ b/examples/opensora_pku/README.md @@ -177,31 +177,48 @@ You can also run video reconstruction given an input video folder. See `scripts/ ### Open-Sora-Plan v1.3.0 Command Line Inference -**To be revised.** +You need download the models manually. +First, you need to download checkpoint including [diffusion model](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/any93x640x640), [vae](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/vae) and [text encoder](https://huggingface.co/google/mt5-xxl), and optional [second text encoder](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k). The [prompt refiner](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner) is optional. + + + + +You can run text-to-video inference on a single Ascend device using the script `scripts/text_condition/single-device/sample_t2v_29x1280.sh` by modifying `--model_path`, `--text_encoder_name_1` and `--ae_path`. The `--caption_refiner` and `--text_encoder_name_2` are optional. + + -You can run text-to-video inference on a single Ascend device using the script `scripts/text_condition/single-device/sample_t2v_29x720p.sh`. ```bash -python opensora/sample/sample_t2v.py \ - --model_path LanguageBind/Open-Sora-Plan-v1.2.0/29x720p \ +# Single NPU +python opensora/sample/sample.py \ + --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ + --version v1_3 \ --num_frames 29 \ - --height 720 \ + --height 704 \ --width 1280 \ - --cache_dir "./" \ - --text_encoder_name google/mt5-xxl \ + --text_encoder_name_1 google/mt5-xxl \ + --text_encoder_name_2 laion/CLIP-ViT-bigG-14-laion2B-39B-b160k \ --text_prompt examples/prompt_list_0.txt \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path LanguageBind/Open-Sora-Plan-v1.2.0/vae\ - --save_img_path "./sample_videos/prompt_list_0_29x720p" \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ + --save_img_path "./sample_videos/prompt_list_0_29x1280_mt5_openclip" \ --fps 24 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ --enable_tiling \ --max_sequence_length 512 \ --sample_method EulerAncestralDiscrete \ - --model_type "dit" \ + --num_samples_per_prompt 1 \ + --rescale_betas_zero_snr \ + --prediction_type "v_prediction" \ + --mode 1 ``` -You can change the `num_frames`, `height` and `width` to match with the training shape of different checkpoints, e.g., `29x480p` requires `num_frames=29`, `height=480` and `width=640`. In case of oom on your device, you can try to append `--save_memory` to the command above, which enables a more radical tiling strategy for causal vae. +You can change the `num_frames`, `height` and `width`. +Note that DiT model is trained arbitrarily on stride=32. +So keep the resolution of the inference a multiple of 32. `num_frames` needs to be 4n+1, e.g. 93, 77, 61, 45, 29, 1. + + +**To be revised.** If you want to run a multi-device inference, e.g., 8 cards, please use `msrun` and pass `--use_parallel=True` as the example below: diff --git a/examples/opensora_pku/opensora/models/diffusion/__init__.py b/examples/opensora_pku/opensora/models/diffusion/__init__.py index eb493c8ba5..3898faaef7 100644 --- a/examples/opensora_pku/opensora/models/diffusion/__init__.py +++ b/examples/opensora_pku/opensora/models/diffusion/__init__.py @@ -1,8 +1,8 @@ -from .opensora.modeling_opensora import OpenSora_models, OpenSora_models_class +from .opensora.modeling_opensora import OpenSora_v1_3_models, OpenSora_v1_3_models_class Diffusion_models = {} -Diffusion_models.update(OpenSora_models) +Diffusion_models.update(OpenSora_v1_3_models) Diffusion_models_class = {} -Diffusion_models_class.update(OpenSora_models_class) +Diffusion_models_class.update(OpenSora_v1_3_models_class) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index 05896ff2c7..0d85fafba3 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -17,41 +17,11 @@ from mindone.diffusers.models.normalization import AdaLayerNormSingle from mindone.diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, _add_variant, _get_model_file, deprecate -from .modules import BasicTransformerBlock, LayerNorm, OverlapPatchEmbed2D, OverlapPatchEmbed3D, PatchEmbed2D +from examples.opensora_pku.opensora.models.diffusion.opensora.modules import BasicTransformerBlock, LayerNorm, Attention, PatchEmbed2D -logger = logging.getLogger(__name__) - - -class OpenSoraT2V(ModelMixin, ConfigMixin): +class OpenSoraT2V_v1_3(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True - """ - A 2D Transformer model for image-like data. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - The number of channels in the input and output (specify if the input is **continuous**). - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). - This is fixed during training since it is used to learn a number of position embeddings. - num_vector_embeds (`int`, *optional*): - The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). - Includes the class for the masked latent pixel. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): - The number of diffusion steps used during training. Pass if at least one of the norm_layers is - `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are - added to the hidden states. - - During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. - attention_bias (`bool`, *optional*): - Configure if the `TransformerBlocks` attention should contain a bias parameter. - """ - @register_to_config def __init__( self, @@ -61,108 +31,91 @@ def __init__( out_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, - norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: Optional[int] = None, + attention_bias: bool = True, + sample_size_h: Optional[int] = None, + sample_size_w: Optional[int] = None, sample_size_t: Optional[int] = None, - num_vector_embeds: Optional[int] = None, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None, activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - use_linear_projection: bool = False, only_cross_attention: bool = False, double_self_attention: bool = False, upcast_attention: bool = False, - norm_type: str = "layer_norm", - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-5, - attention_type: str = "default", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, caption_channels: int = None, - interpolation_scale_h: float = None, - interpolation_scale_w: float = None, - interpolation_scale_t: float = None, - use_additional_conditions: Optional[bool] = None, - attention_mode: str = "xformers", - downsampler: str = None, - use_recompute=False, - use_rope: bool = False, - FA_dtype=ms.bfloat16, - num_no_recompute: int = 0, - use_stable_fp32: bool = False, + interpolation_scale_h: float = 1.0, + interpolation_scale_w: float = 1.0, + interpolation_scale_t: float = 1.0, + sparse1d: bool = False, + sparse_n: int = 2, + + attention_mode: str = "xformers", #NEW + use_recompute=False, #NEW + FA_dtype=ms.bfloat16, #NEW + num_no_recompute: int = 0, #NEW ): super().__init__() - - # Validate inputs. - if patch_size is not None: - if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]: - raise NotImplementedError( - f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." - ) - elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None: - raise ValueError( - f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." - ) - # Set some common variables used across the board. - self.use_rope = use_rope - self.use_linear_projection = use_linear_projection - self.interpolation_scale_t = interpolation_scale_t - self.interpolation_scale_h = interpolation_scale_h - self.interpolation_scale_w = interpolation_scale_w - self.downsampler = downsampler - self.caption_channels = caption_channels - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim - self.in_channels = in_channels self.out_channels = in_channels if out_channels is None else out_channels - self.gradient_checkpointing = use_recompute - self.config.hidden_size = self.inner_dim - use_additional_conditions = False - self.use_additional_conditions = use_additional_conditions - self.use_recompute = use_recompute - self.FA_dtype = FA_dtype - - # 1. Transformer2DModel can process both standard continuous images of shape\ - # `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` - # Define whether input is continuous or discrete depending on configuration - assert in_channels is not None and patch_size is not None - - if norm_type == "layer_norm" and num_embeds_ada_norm is not None: - deprecation_message = ( - f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" - " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." - " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" - " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" - " would be very nice if you could open a Pull request for the `transformer/config.json` file" + self.config.hidden_size = self.config.num_attention_heads * self.config.attention_head_dim + self.gradient_checkpointing = use_recompute #NEW + self.use_recompute = use_recompute #NEW + self.FA_dtype = FA_dtype #NEW + self.attention_mode = attention_mode #NEW + self._init_patched_inputs() + + def _init_patched_inputs(self): + + self.config.sample_size = (self.config.sample_size_h, self.config.sample_size_w) + interpolation_scale_thw = ( + self.config.interpolation_scale_t, + self.config.interpolation_scale_h, + self.config.interpolation_scale_w ) - deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) - norm_type = "ada_norm" - # 2. Initialize the right blocks. - # Initialize the output blocks and other projection blocks when necessary. - self._init_patched_inputs(norm_type=norm_type) - - if self.use_recompute: - num_no_recompute = self.config.num_no_recompute - num_blocks = len(self.transformer_blocks) - assert num_no_recompute >= 0, "Expect to have num_no_recompute as a positive integer." - assert ( - num_no_recompute <= num_blocks - ), "Expect to have num_no_recompute as an integer no greater than the number of blocks," - f"but got {num_no_recompute} and {num_blocks}." - logger.info(f"Excluding {num_no_recompute} blocks from the recomputation list.") - for bidx, block in enumerate(self.transformer_blocks): - if bidx < num_blocks - num_no_recompute: - self.recompute(block) - self.silu = nn.SiLU() - self.maxpool2d = nn.MaxPool2d( - kernel_size=(self.patch_size, self.patch_size), stride=(self.patch_size, self.patch_size) + + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.config.caption_channels, hidden_size=self.config.hidden_size + ) + + self.pos_embed = PatchEmbed2D( + patch_size=self.config.patch_size, + in_channels=self.config.in_channels, + embed_dim=self.config.hidden_size, + ) + self.transformer_blocks = nn.CellList( + [ + BasicTransformerBlock( + self.config.hidden_size, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + attention_bias=self.config.attention_bias, + only_cross_attention=self.config.only_cross_attention, + double_self_attention=self.config.double_self_attention, + upcast_attention=self.config.upcast_attention, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + interpolation_scale_thw=interpolation_scale_thw, + sparse1d=self.config.sparse1d if i > 1 and i < 30 else False, + sparse_n=self.config.sparse_n, + sparse_group=i % 2 == 1, + ) + for i in range(self.config.num_layers) + ] + ) + self.norm_out = LayerNorm(self.config.hidden_size, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = ms.Parameter(ops.randn((2, self.config.hidden_size)) / self.config.hidden_size**0.5) + self.proj_out = nn.Dense( + self.config.hidden_size, self.config.patch_size_t * self.config.patch_size * self.config.patch_size * self.out_channels ) + self.adaln_single = AdaLayerNormSingle(self.config.hidden_size) self.max_pool3d = nn.MaxPool3d( - kernel_size=(self.patch_size_t, self.patch_size, self.patch_size), - stride=(self.patch_size_t, self.patch_size, self.patch_size), + kernel_size=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size), + stride=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size) ) # rewrite class method to allow the state dict as input @@ -296,192 +249,25 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value - + def get_attention_mask(self, attention_mask): if attention_mask is not None: - if self.config.attention_mode != "math": + if self.attention_mode != "math": attention_mask = attention_mask.to(ms.bool_) return attention_mask - def _init_patched_inputs(self, norm_type): - assert self.config.sample_size_t is not None, "OpenSoraT2V over patched input must provide sample_size_t" - assert self.config.sample_size is not None, "OpenSoraT2V over patched input must provide sample_size" - # assert not (self.config.sample_size_t == 1 and self.config.patch_size_t == 2), "Image do not need patchfy in t-dim" - - self.num_frames = self.config.sample_size_t - self.config.sample_size = to_2tuple(self.config.sample_size) - self.height = self.config.sample_size[0] - self.width = self.config.sample_size[1] - self.patch_size_t = self.config.patch_size_t - self.patch_size = self.config.patch_size - interpolation_scale_t = ( - ((self.config.sample_size_t - 1) // 16 + 1) - if self.config.sample_size_t % 2 == 1 - else self.config.sample_size_t / 16 - ) - interpolation_scale_t = ( - self.config.interpolation_scale_t - if self.config.interpolation_scale_t is not None - else interpolation_scale_t - ) - interpolation_scale = ( - self.config.interpolation_scale_h - if self.config.interpolation_scale_h is not None - else self.config.sample_size[0] / 30, - self.config.interpolation_scale_w - if self.config.interpolation_scale_w is not None - else self.config.sample_size[1] / 40, - ) - - if self.config.downsampler is not None and len(self.config.downsampler) == 9: - self.pos_embed = OverlapPatchEmbed3D( - num_frames=self.config.sample_size_t, - height=self.config.sample_size[0], - width=self.config.sample_size[1], - patch_size_t=self.config.patch_size_t, - patch_size=self.config.patch_size, - in_channels=self.in_channels, - embed_dim=self.inner_dim, - interpolation_scale=interpolation_scale, - interpolation_scale_t=interpolation_scale_t, - use_abs_pos=not self.config.use_rope, - ) - elif self.config.downsampler is not None and len(self.config.downsampler) == 7: - self.pos_embed = OverlapPatchEmbed2D( - num_frames=self.config.sample_size_t, - height=self.config.sample_size[0], - width=self.config.sample_size[1], - patch_size_t=self.config.patch_size_t, - patch_size=self.config.patch_size, - in_channels=self.in_channels, - embed_dim=self.inner_dim, - interpolation_scale=interpolation_scale, - interpolation_scale_t=interpolation_scale_t, - use_abs_pos=not self.config.use_rope, - ) - - else: - self.pos_embed = PatchEmbed2D( - num_frames=self.config.sample_size_t, - height=self.config.sample_size[0], - width=self.config.sample_size[1], - patch_size_t=self.config.patch_size_t, - patch_size=self.config.patch_size, - in_channels=self.in_channels, - embed_dim=self.inner_dim, - interpolation_scale=interpolation_scale, - interpolation_scale_t=interpolation_scale_t, - use_abs_pos=not self.config.use_rope, - ) - interpolation_scale_thw = (interpolation_scale_t, *interpolation_scale) - self.transformer_blocks = nn.CellList( - [ - BasicTransformerBlock( - self.inner_dim, - self.config.num_attention_heads, - self.config.attention_head_dim, - dropout=self.config.dropout, - cross_attention_dim=self.config.cross_attention_dim, - activation_fn=self.config.activation_fn, - num_embeds_ada_norm=self.config.num_embeds_ada_norm, - attention_bias=self.config.attention_bias, - only_cross_attention=self.config.only_cross_attention, - double_self_attention=self.config.double_self_attention, - upcast_attention=self.config.upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=self.config.norm_elementwise_affine, - norm_eps=self.config.norm_eps, - attention_type=self.config.attention_type, - attention_mode=self.config.attention_mode, - FA_dtype=self.config.FA_dtype, - downsampler=self.config.downsampler, - use_rope=self.config.use_rope, - interpolation_scale_thw=interpolation_scale_thw, - ) - for _ in range(self.config.num_layers) - ] - ) - - if self.config.norm_type != "ada_norm_single": - self.norm_out = LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out_1 = nn.Dense(self.inner_dim, 2 * self.inner_dim) - self.proj_out_2 = nn.Dense( - self.inner_dim, - self.config.patch_size_t * self.config.patch_size * self.config.patch_size * self.out_channels, - ) - elif self.config.norm_type == "ada_norm_single": - self.norm_out = LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = ms.Parameter(ops.randn(2, self.inner_dim) / self.inner_dim**0.5) - self.proj_out = nn.Dense( - self.inner_dim, - self.config.patch_size_t * self.config.patch_size * self.config.patch_size * self.out_channels, - ) - - # PixArt-Alpha blocks. - self.adaln_single = None - if self.config.norm_type == "ada_norm_single": - # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use - # additional conditions until we find better name - self.adaln_single = AdaLayerNormSingle( - self.inner_dim, use_additional_conditions=self.use_additional_conditions - ) - - self.caption_projection = None - if self.caption_channels is not None: - self.caption_projection = PixArtAlphaTextProjection( - in_features=self.caption_channels, hidden_size=self.inner_dim - ) - def construct( self, hidden_states: ms.Tensor, timestep: Optional[ms.Tensor] = None, encoder_hidden_states: Optional[ms.Tensor] = None, - added_cond_kwargs: Dict[str, ms.Tensor] = None, - class_labels: Optional[ms.Tensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, attention_mask: Optional[ms.Tensor] = None, encoder_attention_mask: Optional[ms.Tensor] = None, - use_image_num: int = 0, + # return_dict: bool = True, + **kwargs, ): - """ - The [`Transformer2DModel`] forward method. - - Args: - hidden_states (`ms.Tensor` of shape `(batch size, num latent pixels)` if discrete, \ - `ms.Tensor` of shape `(batch size, frame, channel, height, width)` if continuous): Input `hidden_states`. - encoder_hidden_states ( `ms.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `ms.Tensor`, *optional*): - Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. - class_labels ( `ms.Tensor` of shape `(batch size, num classes)`, *optional*): - Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in - `AdaLayerZeroNorm`. - cross_attention_kwargs ( `Dict[str, Any]`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - attention_mask ( `ms.Tensor`, *optional*): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. - encoder_attention_mask ( `ms.Tensor`, *optional*): - Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: - - * Mask `(batch, sequence_length)` True = keep, False = discard. - * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. - - If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format - above. This bias will be added to the cross-attention scores. - Returns: - a noise tensor - """ + batch_size, c, frame, h, w = hidden_states.shape - frame = frame - use_image_num # 20-4=16 - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. @@ -492,466 +278,175 @@ def construct( # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) - attention_mask_vid, attention_mask_img = None, None if attention_mask is not None and attention_mask.ndim == 4: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) - # b, frame+use_image_num, h, w -> a video with images + # b, frame, h, w -> a video # b, 1, h, w -> only images attention_mask = attention_mask.to(self.dtype) - if get_sequence_parallel_state(): - attention_mask_vid = attention_mask[:, : frame * hccl_info.world_size] # b, frame, h, w - attention_mask_img = attention_mask[:, frame * hccl_info.world_size :] # b, use_image_num, h, w - else: - attention_mask_vid = attention_mask[:, :frame] # b, frame, h, w - attention_mask_img = attention_mask[:, frame:] # b, use_image_num, h, w + attention_mask = attention_mask.unsqueeze(1) # b 1 t h w + attention_mask = self.max_pool3d(attention_mask) + # b 1 t h w -> (b 1) 1 (t h w) + attention_mask = attention_mask.reshape(batch_size, 1, -1) + + # attention_mask = (1 - attention_mask.bool().to(self.dtype)) * -10000.0 #TODO: TBD + attention_mask = self.get_attention_mask(attention_mask) # use bool mask for FA - if attention_mask_vid.numel() > 0: - if self.patch_size_t - 1 > 0: - attention_mask_vid_first_frame = attention_mask_vid[:, :1].repeat(self.patch_size_t - 1, axis=1) - attention_mask_vid = ops.cat([attention_mask_vid_first_frame, attention_mask_vid], axis=1) - attention_mask_vid = attention_mask_vid[:, None, :, :, :] # b 1 t h w - attention_mask_vid = self.max_pool3d(attention_mask_vid) - # b 1 t h w -> (b 1) 1 (t h w) - attention_mask_vid = attention_mask_vid.reshape(batch_size, 1, -1) - if attention_mask_img.numel() > 0: - attention_mask_img = self.maxpool2d(attention_mask_img) - # b i h w -> (b i) 1 (h w) - attention_mask_img = attention_mask_img.reshape(batch_size * attention_mask_img.shape[1], 1, -1) - # do not fill in -10000.0 until MHA - # attention_mask_vid = (1 - attention_mask_vid.bool().to(self.dtype)) * -10000.0 if attention_mask_vid.numel() > 0 else None - # attention_mask_img = (1 - attention_mask_img.bool().to(self.dtype)) * -10000.0 if attention_mask_img.numel() > 0 else None - - if frame == 1 and use_image_num == 0 and not get_sequence_parallel_state(): - attention_mask_img = attention_mask_vid - attention_mask_vid = None # convert encoder_attention_mask to a bias the same way we do for attention_mask - encoder_attention_mask_vid, encoder_attention_mask_img = None, None - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # (b, l) -> (b, 1, l) - encoder_attention_mask = encoder_attention_mask.to(self.dtype) - - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: - encoder_attention_mask = encoder_attention_mask.to(self.dtype) - # b, 1+use_image_num, l -> a video with images - # b, 1, l -> only images - in_t = encoder_attention_mask.shape[1] - encoder_attention_mask_vid = encoder_attention_mask[:, : in_t - use_image_num] # b, 1, l - # b 1 l -> (b 1) 1 l - encoder_attention_mask_vid = encoder_attention_mask_vid if encoder_attention_mask_vid.numel() > 0 else None - encoder_attention_mask_img = encoder_attention_mask[:, in_t - use_image_num :] # b, use_image_num, l - # b i l -> (b i) 1 l - encoder_attention_mask_img = ( - encoder_attention_mask_img.reshape(-1, encoder_attention_mask.shape[-1]).unsqeeze(1) - if encoder_attention_mask_img.numel() > 0 - else None - ) - - if frame == 1 and use_image_num == 0 and not get_sequence_parallel_state(): - encoder_attention_mask_img = encoder_attention_mask_vid - encoder_attention_mask_vid = None + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: + # b, 1, l + # encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 + encoder_attention_mask = self.get_attention_mask(encoder_attention_mask) # use bool mask for FA - attention_mask_vid = self.get_attention_mask(attention_mask_vid) # use bool mask for FA - encoder_attention_mask_vid = self.get_attention_mask(encoder_attention_mask_vid) - attention_mask_img = self.get_attention_mask(attention_mask_img) - encoder_attention_mask_img = self.get_attention_mask(encoder_attention_mask_img) - - # # Retrieve lora scale. - # lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 # 1. Input - frame = ((frame - 1) // self.patch_size_t + 1) if frame % 2 == 1 else frame // self.patch_size_t # patchfy - # print('frame', frame) - height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + frame = ((frame - 1) // self.config.patch_size_t + 1) if frame % 2 == 1 else frame // self.config.patch_size_t # patchfy + height, width = hidden_states.shape[-2] // self.config.patch_size, hidden_states.shape[-1] // self.config.patch_size - added_cond_kwargs = {"resolution": None, "aspect_ratio": None} - ( - hidden_states_vid, - hidden_states_img, - encoder_hidden_states_vid, - encoder_hidden_states_img, - timestep_vid, - timestep_img, - embedded_timestep_vid, - embedded_timestep_img, - ) = self._operate_on_patched_inputs( - hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size, frame, use_image_num + + hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( + hidden_states, encoder_hidden_states, timestep, batch_size, frame ) - # 2. Blocks - # BS H -> S B H if get_sequence_parallel_state(): - if hidden_states_vid is not None: - # b s h -> s b h - hidden_states_vid = hidden_states_vid.swapaxes(0, 1).contiguous() - # b s h -> s b h - encoder_hidden_states_vid = encoder_hidden_states_vid.swapaxes(0, 1).contiguous() - timestep_vid = timestep_vid.view(batch_size, 6, -1).swapaxes(0, 1).contiguous() - # print('timestep_vid', timestep_vid.shape) + # To + # x (t*h*w b d) or (t//sp*h*w b d) + # cond_1 (l b d) or (l//sp b d) + # b s h -> s b h + hidden_states = hidden_states.swapaxes(0, 1).contiguous() + # b s h -> s b h + encoder_hidden_states = encoder_hidden_states.swapaxes(0,1).contiguous() + timestep = timestep.view(batch_size, 6, -1).swapaxes(0, 1).contiguous() + + sparse_mask = {} + # if npu_config is None: + # if get_sequence_parallel_state(): + # head_num = self.config.num_attention_heads // hccl_info.world_size + # else: + # head_num = self.config.num_attention_heads + # else: + head_num = None + for sparse_n in [1, 4]: + sparse_mask[sparse_n] = Attention.prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n, head_num) + + # 2. Blocks + for i, block in enumerate(self.transformer_blocks): + if i > 1 and i < 30: + attention_mask, encoder_attention_mask = sparse_mask[block.attn1.processor.sparse_n][block.attn1.processor.sparse_group] + else: + attention_mask, encoder_attention_mask = sparse_mask[1][block.attn1.processor.sparse_group] - for block in self.transformer_blocks: - if hidden_states_vid is not None: - hidden_states_vid = block( - hidden_states_vid, - attention_mask=attention_mask_vid, - encoder_hidden_states=encoder_hidden_states_vid, - encoder_attention_mask=encoder_attention_mask_vid, - timestep=timestep_vid, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - frame=frame, - height=height, - width=width, - ) - if hidden_states_img is not None: - hidden_states_img = block( - hidden_states_img, - attention_mask=attention_mask_img, - encoder_hidden_states=encoder_hidden_states_img, - encoder_attention_mask=encoder_attention_mask_img, - timestep=timestep_img, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - frame=1, - height=height, - width=width, - ) + # if self.training and self.gradient_checkpointing: #TODO: training + + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + frame=frame, + height=height, + width=width, + ) # BSH if get_sequence_parallel_state(): - if hidden_states_vid is not None: - # s b h -> b s h - hidden_states_vid = hidden_states_vid.swapaxes(0, 1).contiguous() + # To (b, t*h*w, h) or (b, t//sp*h*w, h) + # s b h -> b s h + hidden_states = hidden_states.swapaxes(0, 1).contiguous() # 3. Output - output_vid, output_img = None, None - if hidden_states_vid is not None: - output_vid = self._get_output_for_patched_inputs( - hidden_states=hidden_states_vid, - timestep=timestep_vid, - class_labels=class_labels, - embedded_timestep=embedded_timestep_vid, - num_frames=frame, - height=height, - width=width, - ) # b c t h w - if hidden_states_img is not None: - output_img = self._get_output_for_patched_inputs( - hidden_states=hidden_states_img, - timestep=timestep_img, - class_labels=class_labels, - embedded_timestep=embedded_timestep_img, - num_frames=1, - height=height, - width=width, - ) # b c 1 h w - if use_image_num != 0: - # (b i) c 1 h w -> b c i h w - _, c, _, h, w = output_img.shape - output_img = output_img.reshape(-1, use_image_num, c, 1, h, w).swapaxes(1, 2).squeeze(3) - output = None - if output_vid is not None and output_img is not None: - output = ops.cat([output_vid, output_img], axis=2) - elif output_vid is not None: - output = output_vid - elif output_img is not None: - output = output_img - return output - - @classmethod - def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, **kwargs): - if subfolder is not None: - pretrained_model_path = os.path.join(pretrained_model_path, subfolder) - - config_file = os.path.join(pretrained_model_path, "config.json") - if not os.path.isfile(config_file): - raise RuntimeError(f"{config_file} does not exist") - with open(config_file, "r") as f: - config = json.load(f) - - model = cls.from_config(config, **kwargs) - return model - - def construct_with_cfg(self, x, timestep, class_labels=None, cfg_scale=7.0, attention_mask=None): - """ - Forward pass of Latte, but also batches the unconditional forward pass for classifier-free guidance. - """ - # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb - half = x[: len(x) // 2] - combined = ops.cat([half, half], axis=0) - model_out = self.construct(combined, timestep, class_labels=class_labels, attention_mask=attention_mask) - # For exact reproducibility reasons, we apply classifier-free guidance on only - # three channels by default. The standard approach to cfg applies it to all channels. - # This can be done by uncommenting the following line and commenting-out the line following that. - eps, rest = model_out[:, :, : self.in_channels], model_out[:, :, self.in_channels :] - cond_eps, uncond_eps = ops.split(eps, len(eps) // 2, axis=0) - half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) - eps = ops.cat([half_eps, half_eps], axis=0) - return ops.cat([eps, rest], axis=2) - - def recompute(self, b): - if not b._has_config_recompute: - b.recompute(parallel_optimizer_comm_recompute=True) - if isinstance(b, nn.CellList): - self.recompute(b[-1]) - elif ms.get_context("mode") == ms.GRAPH_MODE: - b.add_flags(output_no_recompute=True) + output = self._get_output_for_patched_inputs( + hidden_states=hidden_states, + timestep=timestep, + embedded_timestep=embedded_timestep, + num_frames=frame, + height=height, + width=width, + ) # b c t h w - @classmethod - def load_from_checkpoint(cls, model, ckpt_path): - if os.path.isdir(ckpt_path) or ckpt_path.endswith(".safetensors"): - return cls.load_from_safetensors(model, ckpt_path) - elif ckpt_path.endswith(".ckpt"): - return cls.load_from_ms_checkpoint(ckpt_path) - else: - raise ValueError("Only support safetensors pretrained ckpt or MindSpore pretrained ckpt!") - - @classmethod - def load_from_safetensors(cls, model, ckpt_path): - if os.path.isdir(ckpt_path): - ckpts = glob.glob(os.path.join(ckpt_path, "*.safetensors")) - n_ckpt = len(ckpts) - assert ( - n_ckpt == 1 - ), f"Expect to find only one safetenesors file under {ckpt_path}, but found {n_ckpt} .safetensors files." - model_file = ckpts[0] - pretrained_model_name_or_path = ckpt_path - elif ckpt_path.endswith(".safetensors"): - model_file = ckpt_path - pretrained_model_name_or_path = os.path.dirname(ckpt_path) - state_dict = load_state_dict(model_file, variant=None) - model._convert_deprecated_attention_blocks(state_dict) - - model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( - model, - state_dict, - model_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=False, - ) - loading_info = { - "missing_keys": missing_keys, - "unexpected_keys": unexpected_keys, - "mismatched_keys": mismatched_keys, - "error_msgs": error_msgs, - } - logger.info(loading_info) - return model + return output - @classmethod - def load_from_ms_checkpoint(self, model, ckpt_path): - sd = ms.load_checkpoint(ckpt_path) - # filter 'network.' prefix - rm_prefix = ["network."] - all_pnames = list(sd.keys()) - for pname in all_pnames: - for pre in rm_prefix: - if pname.startswith(pre): - new_pname = pname.replace(pre, "") - sd[new_pname] = sd.pop(pname) - m, u = ms.load_param_into_net(model, sd) - print("net param not load: ", m, len(m)) - print("ckpt param not load: ", u, len(u)) - return model + def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, batch_size, frame): + + hidden_states = self.pos_embed(hidden_states.to(self.dtype)) # (b, t*h*w, d) - def _operate_on_patched_inputs( - self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size, frame, use_image_num - ): - # batch_size = hidden_states.shape[0] - hidden_states_vid, hidden_states_img = self.pos_embed(hidden_states.to(self.dtype), frame) - timestep_vid, timestep_img = None, None - embedded_timestep_vid, embedded_timestep_img = None, None - encoder_hidden_states_vid, encoder_hidden_states_img = None, None + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype + ) # b 6d, b d - if self.adaln_single is not None: - if self.use_additional_conditions and added_cond_kwargs is None: - raise ValueError( - "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." - ) - timestep, embedded_timestep = self.adaln_single( - timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype - ) # b 6d, b d - if hidden_states_vid is None: - timestep_img = timestep - embedded_timestep_img = embedded_timestep - else: - timestep_vid = timestep - embedded_timestep_vid = embedded_timestep - if hidden_states_img is not None: - # b d -> (b i) d - timestep_img = timestep.repeat_interleave(use_image_num, dim=0).contiguous() - # b d -> (b i) d - embedded_timestep_img = embedded_timestep_img.repeat_interleave(use_image_num, dim=0).contiguous() - if self.caption_projection is not None: - encoder_hidden_states = self.caption_projection( - encoder_hidden_states - ) # b, 1+use_image_num, l, d or b, 1, l, d - if hidden_states_vid is None: - # b 1 l d -> (b 1) l d - encoder_hidden_states_img = encoder_hidden_states.reshape( - -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1] - ) - else: - # b 1 l d -> (b 1) l d - encoder_hidden_states_vid = encoder_hidden_states[:, :1].reshape( - -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1] - ) - if hidden_states_img is not None: - encoder_hidden_states_img = encoder_hidden_states[:, 1:].reshape( - -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1] - ) + encoder_hidden_states = self.caption_projection(encoder_hidden_states) # b, 1, l, d + assert encoder_hidden_states.shape[1] == 1, f'encoder_hidden_states.shape is {encoder_hidden_states}' + # b 1 l d -> (b 1) l d + encoder_hidden_states = encoder_hidden_states.reshape(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]) - return ( - hidden_states_vid, - hidden_states_img, - encoder_hidden_states_vid, - encoder_hidden_states_img, - timestep_vid, - timestep_img, - embedded_timestep_vid, - embedded_timestep_img, - ) + return hidden_states, encoder_hidden_states, timestep, embedded_timestep + + def _get_output_for_patched_inputs( - self, hidden_states, timestep, class_labels, embedded_timestep, num_frames, height=None, width=None - ): - # import ipdb;ipdb.set_trace() - if self.config.norm_type != "ada_norm_single": - conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=self.dtype) - shift, scale = mint.chunk(self.proj_out_1(self.silu(conditioning)), 2, dim=1) - hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] - hidden_states = self.proj_out_2(hidden_states) - elif self.config.norm_type == "ada_norm_single": - shift, scale = mint.chunk(self.scale_shift_table[None] + embedded_timestep[:, None], 2, dim=1) - hidden_states = self.norm_out(hidden_states) - # Modulation - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.squeeze(1) if hidden_states.shape[1] == 1 else hidden_states + self, hidden_states, timestep, embedded_timestep, num_frames, height, width + ): + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, axis=1) + hidden_states = self.norm_out(hidden_states) #BSH -> BSH + hidden_states = hidden_states.squeeze(1) if hidden_states.shape[1] == 1 else hidden_states + + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) if hidden_states.shape[1] == 1 else hidden_states # unpatchify - if self.adaln_single is None: - height = width = int(hidden_states.shape[1] ** 0.5) hidden_states = hidden_states.reshape( - ( - -1, - num_frames, - height, - width, - self.patch_size_t, - self.patch_size, - self.patch_size, - self.out_channels, - ) + -1, num_frames, height, width, self.config.patch_size_t, self.config.patch_size, self.config.patch_size, self.out_channels ) - # nthwopqc->nctohpwq + # nthwopqc -> nctohpwq hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) output = hidden_states.reshape( - ( - -1, - self.out_channels, - num_frames * self.patch_size_t, - height * self.patch_size, - width * self.patch_size, - ) + -1, self.out_channels, + num_frames * self.config.patch_size_t, height * self.config.patch_size, width * self.config.patch_size ) - return output +def OpenSoraT2V_v1_3_2B_122(**kwargs): + kwargs.pop('skip_connection', None) + return OpenSoraT2V_v1_3( + num_layers=32, attention_head_dim=96, num_attention_heads=24, patch_size_t=1, patch_size=2, + caption_channels=4096, cross_attention_dim=2304, activation_fn="gelu-approximate", **kwargs + ) -def OpenSoraT2V_S_122(**kwargs): - return OpenSoraT2V( - num_layers=28, - attention_head_dim=96, - num_attention_heads=16, - patch_size_t=1, - patch_size=2, - norm_type="ada_norm_single", - caption_channels=4096, - cross_attention_dim=1536, - **kwargs, - ) - - -def OpenSoraT2V_B_122(**kwargs): - return OpenSoraT2V( - num_layers=32, - attention_head_dim=96, - num_attention_heads=16, - patch_size_t=1, - patch_size=2, - norm_type="ada_norm_single", - caption_channels=4096, - cross_attention_dim=1920, - **kwargs, - ) - - -def OpenSoraT2V_L_122(**kwargs): - return OpenSoraT2V( - num_layers=40, - attention_head_dim=128, - num_attention_heads=16, - patch_size_t=1, - patch_size=2, - norm_type="ada_norm_single", - caption_channels=4096, - cross_attention_dim=2048, - **kwargs, - ) - - -def OpenSoraT2V_ROPE_L_122(**kwargs): - return OpenSoraT2V( - num_layers=32, - attention_head_dim=96, - num_attention_heads=24, - patch_size_t=1, - patch_size=2, - norm_type="ada_norm_single", - caption_channels=4096, - cross_attention_dim=2304, - **kwargs, - ) - - -OpenSora_models = { - "OpenSoraT2V-S/122": OpenSoraT2V_S_122, - "OpenSoraT2V-B/122": OpenSoraT2V_B_122, - "OpenSoraT2V-L/122": OpenSoraT2V_L_122, - "OpenSoraT2V-ROPE-L/122": OpenSoraT2V_ROPE_L_122, +OpenSora_v1_3_models = { + "OpenSoraT2V_v1_3-2B/122": OpenSoraT2V_v1_3_2B_122, # 2.7B } -OpenSora_models_class = { - "OpenSoraT2V-S/122": OpenSoraT2V, - "OpenSoraT2V-B/122": OpenSoraT2V, - "OpenSoraT2V-L/122": OpenSoraT2V, - "OpenSoraT2V-ROPE-L/122": OpenSoraT2V, +OpenSora_v1_3_models_class = { + "OpenSoraT2V_v1_3-2B/122": OpenSoraT2V_v1_3, } -if __name__ == "__main__": +if __name__ == '__main__': from opensora.models.causalvideovae import ae_stride_config - args = type( - "args", - (), - { - "ae": "CausalVAEModel_D4_4x8x8", - "use_rope": True, - "model_max_length": 512, - "max_height": 320, - "max_width": 240, - "num_frames": 1, - "use_image_num": 0, - "interpolation_scale_t": 1, - "interpolation_scale_h": 1, - "interpolation_scale_w": 1, - }, + args = type('args', (), + { + 'ae': "CausalVAEModel_D4_4x8x8", #'WFVAEModel_D8_4x8x8', + 'model_max_length': 300, + 'max_height': 256, + 'max_width': 512, + 'num_frames': 33, + 'compress_kv_factor': 1, + 'interpolation_scale_t': 1, + 'interpolation_scale_h': 1, + 'interpolation_scale_w': 1, + "sparse1d": True, + "sparse_n": 4, + "rank": 64, + } ) - b = 16 + b = 2 c = 8 cond_c = 4096 num_timesteps = 1000 @@ -959,78 +454,47 @@ def OpenSoraT2V_ROPE_L_122(**kwargs): latent_size = (args.max_height // ae_stride_h, args.max_width // ae_stride_w) num_frames = (args.num_frames - 1) // ae_stride_t + 1 - model = OpenSoraT2V_ROPE_L_122( - in_channels=c, - out_channels=c, - sample_size=latent_size, - sample_size_t=num_frames, - activation_fn="gelu-approximate", + model = OpenSoraT2V_v1_3_2B_122( + in_channels=c, + out_channels=c, + sample_size_h=latent_size, + sample_size_w=latent_size, + sample_size_t=num_frames, + # activation_fn="gelu-approximate", attention_bias=True, - attention_type="default", double_self_attention=False, norm_elementwise_affine=False, norm_eps=1e-06, - norm_num_groups=32, - num_vector_embeds=None, only_cross_attention=False, upcast_attention=False, - use_linear_projection=False, - use_additional_conditions=False, - downsampler=None, - interpolation_scale_t=args.interpolation_scale_t, - interpolation_scale_h=args.interpolation_scale_h, - interpolation_scale_w=args.interpolation_scale_w, - use_rope=args.use_rope, + interpolation_scale_t=args.interpolation_scale_t, + interpolation_scale_h=args.interpolation_scale_h, + interpolation_scale_w=args.interpolation_scale_w, + sparse1d=args.sparse1d, + sparse_n=args.sparse_n ) - + try: - path = "PixArt-Alpha-XL-2-512.safetensors" + path = "/home_host/susan/workspace/checkpoints/Open-Sora-Plan-v1.3.0/any93x640x640/diffusion_pytorch_model.safetensors" from safetensors.torch import load_file as safe_load - ckpt = safe_load(path, device="cpu") - # import ipdb;ipdb.set_trace() - if ( - ckpt["pos_embed.proj.weight"].shape != model.pos_embed.proj.weight.shape - and ckpt["pos_embed.proj.weight"].ndim == 4 - ): - repeat = model.pos_embed.proj.weight.shape[2] - ckpt["pos_embed.proj.weight"] = ckpt["pos_embed.proj.weight"].unsqueeze(2).repeat(repeat, axis=2) / float( - repeat - ) - del ckpt["proj_out.weight"], ckpt["proj_out.bias"] - msg = model.load_state_dict(ckpt, strict=False) + msg = model.load_state_dict(ckpt, strict=True) print(msg) except Exception as e: print(e) - print(model) - print(f"{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9} B") + # print(model) + # print(f"{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9} B") # import sys;sys.exit() - x = ops.randn( - b, - c, - 1 + (args.num_frames - 1) // ae_stride_t + args.use_image_num, - args.max_height // ae_stride_h, - args.max_width // ae_stride_w, - ) - cond = ops.randn(b, 1 + args.use_image_num, args.model_max_length, cond_c) - attn_mask = ops.randint( - 0, - 2, - ( - b, - 1 + (args.num_frames - 1) // ae_stride_t + args.use_image_num, - args.max_height // ae_stride_h, - args.max_width // ae_stride_w, - ), - ) # B L or B 1+num_images L - cond_mask = ops.randint(0, 2, (b, 1 + args.use_image_num, args.model_max_length)) # B L or B 1+num_images L + x = ops.randn(b, c, 1+(args.num_frames-1)//ae_stride_t, args.max_height//ae_stride_h, args.max_width//ae_stride_w) + cond = ops.randn(b, 1, args.model_max_length, cond_c) + attn_mask = ops.randint(0, 2, (b, 1+(args.num_frames-1)//ae_stride_t, args.max_height//ae_stride_h, args.max_width//ae_stride_w)) # B L or B 1+num_images L + cond_mask = ops.randint(0, 2, (b, 1, args.model_max_length)) # B L or B 1+num_images L timestep = ops.randint(0, 1000, (b,)) model_kwargs = dict( hidden_states=x, encoder_hidden_states=cond, attention_mask=attn_mask, encoder_attention_mask=cond_mask, - use_image_num=args.use_image_num, timestep=timestep, ) model.set_train(False) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py index 0a05a21167..212846dae0 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py @@ -1,197 +1,142 @@ import logging import numbers -import re -from typing import Any, Dict, Optional, Tuple +from typing import Optional, Tuple import numpy as np from opensora.acceleration.communications import AllToAll_SBH from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info -from opensora.npu_config import npu_config import mindspore as ms from mindspore import Parameter, mint, nn, ops from mindspore.common.initializer import initializer -from mindone.diffusers.models.attention import FeedForward, GatedSelfAttentionDense +from mindone.diffusers.models.attention import FeedForward from mindone.diffusers.models.attention_processor import Attention as Attention_ -from mindone.diffusers.models.embeddings import SinusoidalPositionalEmbedding -from mindone.diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero from mindone.utils.version_control import check_valid_flash_attention, choose_flash_attention_dtype from .rope import PositionGetter3D, RoPE3D logger = logging.getLogger(__name__) +class LayerNorm(nn.Cell): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine: bool = True, dtype=ms.float32): + super().__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = tuple(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.gamma = Parameter(initializer("ones", normalized_shape, dtype=dtype)) + self.beta = Parameter(initializer("zeros", normalized_shape, dtype=dtype)) + else: + self.gamma = mint.ones(normalized_shape, dtype=dtype) + self.beta = mint.zeros(normalized_shape, dtype=dtype) + self.layer_norm = ops.LayerNorm(-1, -1, epsilon=eps) -# Positional Embeddings -def get_3d_sincos_pos_embed( - embed_dim, - grid_size, - cls_token=False, - extra_tokens=0, - interpolation_scale=1.0, - base_size=16, -): - """ - grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or - [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - # if isinstance(grid_size, int): - # grid_size = (grid_size, grid_size) - grid_t = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size[0]) / interpolation_scale[0] - grid_h = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size[1]) / interpolation_scale[1] - grid_w = np.arange(grid_size[2], dtype=np.float32) / (grid_size[2] / base_size[2]) / interpolation_scale[2] - grid = np.meshgrid(grid_w, grid_h, grid_t) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([3, 1, grid_size[2], grid_size[1], grid_size[0]]) - pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid) - - if cls_token and extra_tokens > 0: - pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def get_3d_sincos_pos_embed_from_grid(embed_dim, grid): - if embed_dim % 3 != 0: - raise ValueError("embed_dim must be divisible by 3") - - # use 1/3 of dimensions to encode grid_t/h/w - emb_t = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (T*H*W, D/3) - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (T*H*W, D/3) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (T*H*W, D/3) - - emb = np.concatenate([emb_t, emb_h, emb_w], axis=1) # (T*H*W, D) - return emb - - -def get_2d_sincos_pos_embed( - embed_dim, - grid_size, - cls_token=False, - extra_tokens=0, - interpolation_scale=1.0, - base_size=16, -): - """ - grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or - [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - # if isinstance(grid_size, int): - # grid_size = (grid_size, grid_size) - - grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size[0]) / interpolation_scale[0] - grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size[1]) / interpolation_scale[1] - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token and extra_tokens > 0: - pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be divisible by 2") - - # use 1/3 of dimensions to encode grid_t/h/w - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed( - embed_dim, - grid_size, - cls_token=False, - extra_tokens=0, - interpolation_scale=1.0, - base_size=16, -): - """ - grid_size: int of the grid return: pos_embed: [grid_size, embed_dim] or - [1+grid_size, embed_dim] (w/ or w/o cls_token) - """ - # if isinstance(grid_size, int): - # grid_size = (grid_size, grid_size) - - grid = np.arange(grid_size, dtype=np.float32) / (grid_size / base_size) / interpolation_scale - pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) # (H*W, D/2) - if cls_token and extra_tokens > 0: - pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) - """ - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be divisible by 2") - - omega = np.arange(embed_dim // 2, dtype=np.float64) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) + def construct(self, x: ms.Tensor): + x, _, _ = self.layer_norm(x, self.gamma, self.beta) + return x - pos = pos.reshape(-1) # (M,) - out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product +# Different from v1.2 +class PatchEmbed2D(nn.Cell): + """2D Image to Patch Embedding but with video""" - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) + def __init__( + self, + patch_size=16, #2 + in_channels=3, #8 + embed_dim=768, # 24*96=2304 + bias=True, + ): + super().__init__() + self.proj = nn.Conv2d( + in_channels, embed_dim, + kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), has_bias=bias + ) - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb + def construct(self, latent): + b, c, t, h, w = latent.shape # b, c=in_channels, t, h, w + # b c t h w -> (b t) c h w + latent = latent.permute(0, 2, 1, 3, 4).reshape(b*t, c, h, w) # b*t, c, h, w + latent = self.proj(latent) # b*t, embed_dim, h, w + # (b t) c h w -> b (t h w) c + _, c, h, w = latent.shape + latent = latent.reshape(b, -1, c, h, w).permute(0, 1, 3, 4, 2).reshape(b, -1, c) # b, t*h*w, embed_dim + + return latent + + +def get_attention_mask(attention_mask, repeat_num, attention_mode="xformers"): + if attention_mask is not None: + if attention_mode != "math": + attention_mask = attention_mask.to(ms.bool_) + else: + attention_mask = attention_mask.repeat_interleave(repeat_num, dim=-2) + return attention_mask class Attention(Attention_): - def __init__(self, downsampler, attention_mode, use_rope, interpolation_scale_thw, **kwags): + def __init__( + self, interpolation_scale_thw, sparse1d, sparse_n, + sparse_group, is_cross_attn, attention_mode="xformers", **kwags + ): FA_dtype = kwags.pop("FA_dtype", ms.bfloat16) - processor = AttnProcessor2_0( + processor = OpenSoraAttnProcessor2_0( + interpolation_scale_thw=interpolation_scale_thw, sparse1d=sparse1d, sparse_n=sparse_n, + sparse_group=sparse_group, is_cross_attn=is_cross_attn, attention_mode=attention_mode, - use_rope=use_rope, - interpolation_scale_thw=interpolation_scale_thw, - FA_dtype=FA_dtype, - dim_head=kwags["dim_head"], - ) + FA_dtype=FA_dtype, dim_head=kwags["dim_head"] + ) kwags["processor"] = processor super().__init__(**kwags) if attention_mode == "xformers": self.set_use_memory_efficient_attention_xformers(True) self.processor = processor - self.downsampler = None - if downsampler: # downsampler k155_s122 - downsampler_ker_size = list(re.search(r"k(\d{2,3})", downsampler).group(1)) # 122 - down_factor = list(re.search(r"s(\d{2,3})", downsampler).group(1)) - downsampler_ker_size = [int(i) for i in downsampler_ker_size] - downsampler_padding = [(i - 1) // 2 for i in downsampler_ker_size] - down_factor = [int(i) for i in down_factor] - - if len(downsampler_ker_size) == 2: - self.downsampler = DownSampler2d( - kwags["query_dim"], - kwags["query_dim"], - kernel_size=downsampler_ker_size, - stride=1, - padding=downsampler_padding, - groups=kwags["query_dim"], - down_factor=down_factor, - down_shortcut=True, - ) - elif len(downsampler_ker_size) == 3: - self.downsampler = DownSampler3d( - kwags["query_dim"], - kwags["query_dim"], - kernel_size=downsampler_ker_size, - stride=1, - padding=downsampler_padding, - groups=kwags["query_dim"], - down_factor=down_factor, - down_shortcut=True, - ) + + @staticmethod + def prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n, head_num): + + attention_mask = attention_mask.unsqueeze(1) + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + l = attention_mask.shape[-1] + if l % (sparse_n * sparse_n) == 0: + pad_len = 0 + else: + pad_len = sparse_n * sparse_n - l % (sparse_n * sparse_n) + + attention_mask_sparse = mint.nn.functional.pad(attention_mask, (0, pad_len, 0, 0), mode="constant", value=-9980.0) + b = attention_mask_sparse.shape[0] + k = sparse_n + m = sparse_n + # b 1 1 (g k) -> (k b) 1 1 g + attention_mask_sparse_1d = attention_mask_sparse.reshape(b, 1, 1, -1, k).permute(4, 0, 1, 2, 3).reshape(b*k, 1, 1, -1) + # b 1 1 (n m k) -> (m b) 1 1 (n k) + attention_mask_sparse_1d_group = attention_mask_sparse.reshape(b, 1, 1, -1, m, k).permute(4, 0, 1, 2, 3, 5).reshape(m*b, 1, 1, -1) + encoder_attention_mask_sparse = encoder_attention_mask.tile((sparse_n, 1, 1, 1)) + # if npu_config is not None: + attention_mask_sparse_1d = get_attention_mask( + attention_mask_sparse_1d, attention_mask_sparse_1d.shape[-1] + ) + attention_mask_sparse_1d_group = get_attention_mask( + attention_mask_sparse_1d_group, attention_mask_sparse_1d_group.shape[-1] + ) + + encoder_attention_mask_sparse_1d = get_attention_mask( + encoder_attention_mask_sparse, attention_mask_sparse_1d.shape[-1] + ) + encoder_attention_mask_sparse_1d_group = encoder_attention_mask_sparse_1d + # else: + # attention_mask_sparse_1d = attention_mask_sparse_1d.repeat_interleave(head_num, dim=1) + # attention_mask_sparse_1d_group = attention_mask_sparse_1d_group.repeat_interleave(head_num, dim=1) + + # encoder_attention_mask_sparse_1d = encoder_attention_mask_sparse.repeat_interleave(head_num, dim=1) + # encoder_attention_mask_sparse_1d_group = encoder_attention_mask_sparse_1d + + return { + False: (attention_mask_sparse_1d, encoder_attention_mask_sparse_1d), + True: (attention_mask_sparse_1d_group, encoder_attention_mask_sparse_1d_group) + } def prepare_attention_mask( self, attention_mask: ms.Tensor, target_length: int, batch_size: int, out_dim: int = 3 @@ -210,17 +155,18 @@ def prepare_attention_mask( The output dimension of the attention mask. Can be either `3` or `4`. Returns: - `ms.Tensor`: The prepared attention mask. + `torch.Tensor`: The prepared attention mask. """ head_size = self.heads if get_sequence_parallel_state(): - head_size = head_size // hccl_info.world_size - if attention_mask is None: + head_size = head_size // hccl_info.world_size # e.g, 24 // 8 + + if attention_mask is None: # b 1 t*h*w in sa, b 1 l in ca return attention_mask current_length: int = attention_mask.shape[-1] if current_length != target_length: - attention_mask = ops.pad(attention_mask, (0, target_length), mode="constant", value=0.0) + attention_mask = mint.nn.functional.pad(attention_mask, (0, target_length), mode="constant", value=0.0) if out_dim == 3: if attention_mask.shape[0] < batch_size * head_size: @@ -233,30 +179,30 @@ def prepare_attention_mask( @ms.jit_class -class AttnProcessor2_0: +class OpenSoraAttnProcessor2_0: r""" - Processor for implementing scaled dot-product attention or xFormers-like memory efficient attention. + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ - def __init__( - self, - attention_mode="xformers", - use_rope=False, - interpolation_scale_thw=(1, 1, 1), - FA_dtype=ms.bfloat16, - dim_head=64, - ): - self.use_rope = use_rope + def __init__(self, interpolation_scale_thw=(1, 1, 1), + sparse1d=False, sparse_n=2, sparse_group=False, is_cross_attn=True, + FA_dtype=ms.bfloat16, dim_head=64, attention_mode = "xformers"): + self.sparse1d = sparse1d + self.sparse_n = sparse_n + self.sparse_group = sparse_group + self.is_cross_attn = is_cross_attn self.interpolation_scale_thw = interpolation_scale_thw - if self.use_rope: - self._init_rope(interpolation_scale_thw, dim_head=dim_head) self.attention_mode = attention_mode + + self._init_rope(interpolation_scale_thw, dim_head=dim_head) + + self.attention_mode = "xformers" #TBD # Currently we only support setting attention_mode to `flash` or `math` assert self.attention_mode in [ "xformers", "math", ], f"Unsupported attention mode {self.attention_mode}. Currently we only support ['xformers', 'math']!" - self.enable_FA = attention_mode == "xformers" + self.enable_FA = self.attention_mode == "xformers" self.FA_dtype = FA_dtype assert self.FA_dtype in [ms.float16, ms.bfloat16], f"Unsupported flash-attention dtype: {self.FA_dtype}" if self.enable_FA: @@ -309,9 +255,9 @@ def run_ms_flash_attention( value = value.view(Bs, key_tokens, heads, -1) # Head dimension is checked in Attention.set_use_memory_efficient_attention_xformers. We maybe pad on head_dim. if attn.head_dim_padding > 0: - query_padded = ops.pad(query, (0, attn.head_dim_padding), mode="constant", value=0.0) - key_padded = ops.pad(key, (0, attn.head_dim_padding), mode="constant", value=0.0) - value_padded = ops.pad(value, (0, attn.head_dim_padding), mode="constant", value=0.0) + query_padded = mint.nn.functional.pad(query, (0, attn.head_dim_padding), mode="constant", value=0.0) + key_padded = mint.nn.functional.pad(key, (0, attn.head_dim_padding), mode="constant", value=0.0) + value_padded = mint.nn.functional.pad(value, (0, attn.head_dim_padding), mode="constant", value=0.0) else: query_padded, key_padded, value_padded = query, key, value flash_attn = ops.operations.nn_ops.FlashAttentionScore( @@ -322,7 +268,7 @@ def run_ms_flash_attention( attention_mask = ~attention_mask if attention_mask.dtype == ms.bool_ else 1 - attention_mask # (b, 1, 1, k_n) - > (b, 1, q_n, k_n), manual broadcast if attention_mask.shape[-2] == 1: - attention_mask = mint.tile(attention_mask.bool(), (1, 1, query_tokens, 1)) + attention_mask = mint.tile(attention_mask.bool(), (1, 1, query_tokens, 1)) attention_mask = attention_mask.to(self.fa_mask_dtype) if input_layout == "BNSD": @@ -371,137 +317,207 @@ def run_math_attention(self, attn, query, key, value, attention_mask): assert attention_mask.shape[1] == 1 attention_mask = attention_mask.repeat_interleave(_head_size, 1) attention_mask = attention_mask.reshape(-1, attention_mask.shape[-2], attention_mask.shape[-1]) - attention_mask = ops.zeros(attention_mask.shape).masked_fill(attention_mask.to(ms.bool_), -10000.0) + attention_mask = mint.zeros(attention_mask.shape).masked_fill(attention_mask.to(ms.bool_), -10000.0) attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = ops.bmm(attention_probs, value) + hidden_states = mint.bmm(attention_probs, value) hidden_states = self._batch_to_head_dim(_head_size, hidden_states) return hidden_states - def _batch_to_head_dim(self, head_size, tensor: ms.Tensor) -> ms.Tensor: - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def _head_to_batch_dim(self, head_size, tensor: ms.Tensor, out_dim: int = 3) -> ms.Tensor: - if tensor.ndim == 3: - batch_size, seq_len, dim = tensor.shape - extra_dim = 1 + # TODO: need consider shapes for parallel seq and non-parallel cases + def _sparse_1d(self, x, frame, height, width): + """ + require the shape of (ntokens x batch_size x dim) + + Convert to sparse groups + Input: + x: shape in S,B,D + Output: + x: shape if sparse_group: (S//sparse_n, sparse_n*B, D), else: (S//sparse_n, sparse_n*B, D) + pad_len: 0 or padding + """ + l = x.shape[0] + assert l == frame*height*width + pad_len = 0 + if l % (self.sparse_n * self.sparse_n) != 0: + pad_len = self.sparse_n * self.sparse_n - l % (self.sparse_n * self.sparse_n) + if pad_len != 0: + x = mint.nn.functional.pad(x, (0, 0, 0, 0, 0, pad_len), mode="constant", value=0.0) + + _, b, d = x.shape + if not self.sparse_group: + # (g k) b d -> g (k b) d + k = self.sparse_n + x = x.reshape(-1, k, b, d).reshape(-1, k * b, d) else: - batch_size, extra_dim, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3) - - if out_dim == 3: - tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) - - return tensor + # (n m k) b d -> (n k) (m b) d + m = self.sparse_n + k = self.sparse_n + x = x.reshape(-1, m, k, b, d).permute(0, 2, 1, 3, 4).reshape(-1, m*b, d) + + return x, pad_len + + def _reverse_sparse_1d(self, x, frame, height, width, pad_len): + """ + require the shape of (ntokens x batch_size x dim) + Convert sparse groups back to original dimension + Input: + x: shape in S,B,D + Output: + x: shape if sparse_group: (S*sparse_n, B//sparse_n, D), else: (S*sparse_n, B//sparse_n, D) + """ + assert x.shape[0] == (frame*height*width+pad_len) // self.sparse_n + g, _, d = x.shape + if not self.sparse_group: + # g (k b) d -> (g k) b d + k = self.sparse_n + x = x.reshape(g, k, -1, d).reshape(g*k, -1, d) + else: + # (n k) (m b) d -> (n m k) b d + m = self.sparse_n + k = self.sparse_n + assert g % k == 0 + n = g // k + x = x.reshape(n, k, m, -1, d).permute(0, 2, 1, 3, 4).reshape(n*m*k, -1, d) + x = x[:frame*height*width, :, :] + return x + + def _sparse_1d_kv(self, x): + """ + require the shape of (ntokens x batch_size x dim) + """ + # s b d -> s (k b) d + x = x.repeat(self.sparse_n, axis = 1) + return x + def __call__( self, attn: Attention, - hidden_states: ms.Tensor, - encoder_hidden_states: Optional[ms.Tensor] = None, + hidden_states: ms.Tensor, + encoder_hidden_states: Optional[ms.Tensor] = None, attention_mask: Optional[ms.Tensor] = None, temb: Optional[ms.Tensor] = None, - frame: int = 8, - height: int = 16, - width: int = 16, + frame: int = 8, + height: int = 16, + width: int = 16, + *args, + **kwargs, ) -> ms.Tensor: - if attn.downsampler is not None: - hidden_states, attention_mask = attn.downsampler(hidden_states, attention_mask, t=frame, h=height, w=width) - frame, height, width = attn.downsampler.t, attn.downsampler.h, attn.downsampler.w - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).swapaxes(1, 2) - else: - channel = None + residual = hidden_states if get_sequence_parallel_state(): sequence_length, batch_size, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) - else: batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attention_mask.view(batch_size, 1, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.swapaxes(1, 2)).swapaxes(1, 2) + ) #BSH + # print(f"hidden_states.shape {hidden_states.shape}") #BSH query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads - if get_sequence_parallel_state(): - query = query.view(-1, attn.heads, head_dim) # [s // sp, b, h * d] -> [s // sp * b, h, d] - key = key.view(-1, attn.heads, head_dim) - value = value.view(-1, attn.heads, head_dim) - # query = attn.q_norm(query) - # key = attn.k_norm(key) - h_size = attn.heads * head_dim + FA_head_num = attn.heads + total_frame = frame + + if get_sequence_parallel_state(): #TODO: to test sp_size = hccl_info.world_size - h_size_sp = h_size // sp_size + FA_head_num = attn.heads // sp_size + total_frame = frame * sp_size # apply all_to_all to gather sequence and split attention heads [s // sp * b, h, d] -> [s * b, h // sp, d] - query = self.alltoall_sbh_q(query).view(-1, batch_size, h_size_sp) - key = self.alltoall_sbh_k(key).view(-1, batch_size, h_size_sp) - value = self.alltoall_sbh_v(value).view(-1, batch_size, h_size_sp) - - if self.use_rope: - query = query.view(-1, batch_size, attn.heads // sp_size, head_dim) - key = key.view(-1, batch_size, attn.heads // sp_size, head_dim) - # require the shape of (batch_size x nheads x ntokens x dim) - pos_thw = self.position_getter(batch_size, t=frame * sp_size, h=height, w=width) + query = self.alltoall_sbh_q(query.view(-1, attn.heads, head_dim)) + key = self.alltoall_sbh_k(key.view(-1, attn.heads, head_dim)) + value = self.alltoall_sbh_v(value.view(-1, attn.heads, head_dim)) + + # print(f'batch: {batch_size}, FA_head_num: {FA_head_num}, head_dim: {head_dim}, total_frame:{total_frame}') + query = query.view(-1, batch_size, FA_head_num, head_dim)# BUG? TODO: to test + key = key.view(-1, batch_size, FA_head_num, head_dim) #BUG ? + + # print(f'q {query.shape}, k {key.shape}, v {value.shape}') + if not self.is_cross_attn: + # require the shape of (ntokens x batch_size x nheads x dim) + pos_thw = self.position_getter(batch_size, t=total_frame, h=height, w=width) + # print(f'pos_thw {pos_thw}') query = self.rope(query, pos_thw) key = self.rope(key, pos_thw) - query = query.view(-1, batch_size, h_size_sp).swapaxes(0, 1) # SBH to BSH - key = key.view(-1, batch_size, h_size_sp).swapaxes(0, 1) - value = value.view(-1, batch_size, h_size_sp).swapaxes(0, 1) - if self.attention_mode == "math": - # FIXME: shape error - hidden_states = self.run_math_attention(attn, query, key, value, attention_mask) - elif self.attention_mode == "xformers": - hidden_states = self.run_ms_flash_attention(attn, query, key, value, attention_mask) - # [s * b, h // sp, d] -> [s // sp * b, h, d] -> [s // sp, b, h * d] - hidden_states = hidden_states.view(batch_size, -1, attn.heads // sp_size, head_dim).transpose(2, 1, 0, 3) - hidden_states = self.alltoall_sbh_out(hidden_states).transpose(1, 2, 0, 3).view(-1, batch_size, h_size) + + query = query.view(-1, batch_size, FA_head_num * head_dim) + key = key.view(-1, batch_size, FA_head_num * head_dim) + value = value.view(-1, batch_size, FA_head_num * head_dim) else: - query = query.view(batch_size, -1, attn.heads, head_dim) - key = key.view(batch_size, -1, attn.heads, head_dim) - # query = attn.q_norm(query) - # key = attn.k_norm(key) - if self.use_rope: + # print(f'batch: {batch_size}, FA_head_num: {FA_head_num}, head_dim: {head_dim}, total_frame:{total_frame}') + query = query.view(batch_size, -1, FA_head_num, head_dim) + key = key.view(batch_size, -1, FA_head_num, head_dim) + # (batch_size x ntokens x nheads x dim) + + # print(f'q {query.shape}, k {key.shape}, v {value.shape}') + if not self.is_cross_attn: # require the shape of (batch_size x ntokens x nheads x dim) - pos_thw = self.position_getter(batch_size, t=frame, h=height, w=width) + pos_thw = self.position_getter(batch_size, t=total_frame, h=height, w=width) + # print(f'pos_thw {pos_thw}') query = self.rope(query, pos_thw) key = self.rope(key, pos_thw) - query = query.view(batch_size, -1, attn.heads * head_dim) - key = key.view(batch_size, -1, attn.heads * head_dim) + + query = query.view(batch_size, -1, FA_head_num * head_dim).swapaxes(0, 1) + key = key.view(batch_size, -1, FA_head_num * head_dim).swapaxes(0, 1) + value = value.swapaxes(0, 1) + + # print(f'q {query.shape}, k {key.shape}, v {value.shape}') #(SBH) + + if self.sparse1d: + query, pad_len = self._sparse_1d(query, total_frame, height, width) + if self.is_cross_attn: + key = self._sparse_1d_kv(key) + value = self._sparse_1d_kv(value) + else: + key, pad_len = self._sparse_1d(key, total_frame, height, width) + value, pad_len = self._sparse_1d(value, total_frame, height, width) + + + # print(f'q {query.shape}, k {key.shape}, v {value.shape}') + query = query.swapaxes(0, 1) # SBH to BSH + key = key.swapaxes(0, 1) + value = value.swapaxes(0, 1) + if self.attention_mode == "math": + # FIXME: shape error + hidden_states = self.run_math_attention(attn, query, key, value, attention_mask) + elif self.attention_mode == "xformers": + hidden_states = self.run_ms_flash_attention(attn, query, key, value, attention_mask) + # if npu_config is not None: + # hidden_states = npu_config.run_attention(query, key, value, attention_mask, "SBH", head_dim, FA_head_num) + # else: + # query = rearrange(query, 's b (h d) -> b h s d', h=FA_head_num) + # key = rearrange(key, 's b (h d) -> b h s d', h=FA_head_num) + # value = rearrange(value, 's b (h d) -> b h s d', h=FA_head_num) + # # 0, -10000 ->(bool) False, True ->(any) True ->(not) False + # # 0, 0 ->(bool) False, False ->(any) False ->(not) True + # # if attention_mask is None or not torch.any(attention_mask.bool()): # 0 mean visible + # # attention_mask = None + # # the output of sdp = (batch, num_heads, seq_len, head_dim) + # with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True): + # hidden_states = scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) # dropout_p=0.0, is_causal=False + # hidden_states = rearrange(hidden_states, 'b h s d -> s b (h d)', h=FA_head_num) + + if self.sparse1d: + hidden_states = hidden_states.swapaxes(0, 1) # BSH -> SBH + hidden_states = self._reverse_sparse_1d(hidden_states, total_frame, height, width, pad_len) + hidden_states = hidden_states.swapaxes(0, 1) # SBH -> BSH + + # [s, b, h // sp * d] -> [s // sp * b, h, d] -> [s // sp, b, h * d] + if get_sequence_parallel_state(): + hidden_states = self.alltoall_sbh_out(hidden_states.reshape(-1, FA_head_num, head_dim)) + hidden_states = hidden_states.view(-1, batch_size, inner_dim) - if self.attention_mode == "math": - hidden_states = self.run_math_attention(attn, query, key, value, attention_mask) - elif self.attention_mode == "xformers": - hidden_states = self.run_ms_flash_attention(attn, query, key, value, attention_mask) hidden_states = hidden_states.to(query.dtype) # linear proj @@ -509,673 +525,11 @@ def __call__( # dropout hidden_states = attn.to_out[1](hidden_states) - if input_ndim == 4: - hidden_states = hidden_states.swapaxes(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - hidden_states = hidden_states / attn.rescale_output_factor - if attn.downsampler is not None: - hidden_states = attn.downsampler.reverse(hidden_states, t=frame, h=height, w=width) return hidden_states - -class LayerNorm(nn.Cell): - def __init__(self, normalized_shape, eps=1e-5, elementwise_affine: bool = True, dtype=ms.float32): - super().__init__() - if isinstance(normalized_shape, numbers.Integral): - normalized_shape = (normalized_shape,) - self.normalized_shape = tuple(normalized_shape) - self.eps = eps - self.elementwise_affine = elementwise_affine - if self.elementwise_affine: - self.gamma = Parameter(initializer("ones", normalized_shape, dtype=dtype)) - self.beta = Parameter(initializer("zeros", normalized_shape, dtype=dtype)) - else: - self.gamma = ops.ones(normalized_shape, dtype=dtype) - self.beta = ops.zeros(normalized_shape, dtype=dtype) - self.layer_norm = ops.LayerNorm(-1, -1, epsilon=eps) - - def construct(self, x: ms.Tensor): - x, _, _ = self.layer_norm(x, self.gamma, self.beta) - return x - - -class PatchEmbed2D(nn.Cell): - """2D Image to Patch Embedding but with 3D positional embedding""" - - def __init__( - self, - num_frames=1, - height=224, - width=224, - patch_size_t=1, - patch_size=16, - in_channels=3, - embed_dim=768, - layer_norm=False, - flatten=True, - bias=True, - interpolation_scale=(1, 1), - interpolation_scale_t=1, - use_abs_pos=True, - ): - super().__init__() - # assert num_frames == 1 - self.use_abs_pos = use_abs_pos - self.flatten = flatten - self.layer_norm = layer_norm - self.embed_dim = embed_dim - - self.proj = nn.Conv2d( - in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), has_bias=bias - ) - if layer_norm: - self.norm = LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) - else: - self.norm = None - self.patch_size_t = patch_size_t - self.patch_size = patch_size - # See: - # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161 - self.height, self.width = height // patch_size, width // patch_size - self.base_size = (height // patch_size, width // patch_size) - self.interpolation_scale = (interpolation_scale[0], interpolation_scale[1]) - pos_embed = get_2d_sincos_pos_embed( - embed_dim, (self.height, self.width), base_size=self.base_size, interpolation_scale=self.interpolation_scale - ) - self.pos_embed = ms.Tensor(pos_embed).float().unsqueeze(0) - self.num_frames = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t - self.base_size_t = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t - self.interpolation_scale_t = interpolation_scale_t - - if get_sequence_parallel_state(): - self.sp_size = hccl_info.world_size - rank_offset = hccl_info.rank % hccl_info.world_size - num_frames = (self.num_frames + self.sp_size - 1) // self.sp_size * self.sp_size - temp_pos_embed = get_1d_sincos_pos_embed( - embed_dim, num_frames, base_size=self.base_size_t, interpolation_scale=self.interpolation_scale_t - ) - num_frames //= self.sp_size - self.temp_pos_st = rank_offset * num_frames - self.temp_pos_ed = (rank_offset + 1) * num_frames - else: - temp_pos_embed = get_1d_sincos_pos_embed( - embed_dim, self.num_frames, base_size=self.base_size_t, interpolation_scale=self.interpolation_scale_t - ) - - self.temp_pos_embed = ms.Tensor(temp_pos_embed).float().unsqueeze(0) - - def construct(self, latent, num_frames): - b, c, t, h, w = latent.shape - video_latent, image_latent = None, None - height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size - # b c t h w -> (b t) c h w - latent = latent.swapaxes(1, 2).reshape(b * t, c, h, w) - - latent = self.proj(latent) - if self.flatten: - latent = latent.flatten(start_dim=2).permute(0, 2, 1) # BCHW -> BNC - if self.layer_norm: - latent = self.norm(latent) - - if self.use_abs_pos: - # Interpolate positional embeddings if needed. - # (For PixArt-Alpha: https://github.com/PixArt-alpha/\ - # PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160) - if self.height != height or self.width != width: - # raise NotImplementedError - pos_embed = get_2d_sincos_pos_embed( - embed_dim=self.pos_embed.shape[-1], - grid_size=(height, width), - base_size=self.base_size, - interpolation_scale=self.interpolation_scale, - ) - pos_embed = ms.Tensor(pos_embed) - pos_embed = pos_embed.float().unsqueeze(0) - else: - pos_embed = self.pos_embed - - if self.num_frames != num_frames: - if get_sequence_parallel_state(): - # f, h -> f, 1, h - temp_pos_embed = self.temp_pos_embed[self.temp_pos_st : self.temp_pos_ed].unsqueeze(1) - else: - temp_pos_embed = get_1d_sincos_pos_embed( - embed_dim=self.temp_pos_embed.shape[-1], - grid_size=num_frames, - base_size=self.base_size_t, - interpolation_scale=self.interpolation_scale_t, - ) - temp_pos_embed = ms.Tensor(temp_pos_embed) - temp_pos_embed = temp_pos_embed.float().unsqueeze(0) - else: - temp_pos_embed = self.temp_pos_embed - - latent = (latent + pos_embed).to(latent.dtype) - - # (b t) n c -> b t n c - latent = latent.reshape(b, t, -1, self.embed_dim) - video_latent, image_latent = latent[:, :num_frames], latent[:, num_frames:] - - if self.use_abs_pos: - # temp_pos_embed = temp_pos_embed.unsqueeze(2) * self.temp_embed_gate.tanh() - temp_pos_embed = temp_pos_embed.unsqueeze(2) - video_latent = ( - (video_latent + temp_pos_embed).to(video_latent.dtype) - if video_latent is not None and video_latent.numel() > 0 - else None - ) - image_latent = ( - (image_latent + temp_pos_embed[:, :1]).to(image_latent.dtype) - if image_latent is not None and image_latent.numel() > 0 - else None - ) - # 'b t n c -> b (t n) c' - video_latent = ( - video_latent.reshape(b, -1, self.embed_dim) - if video_latent is not None and video_latent.numel() > 0 - else None - ) - # 'b t n c -> (b t) n c' - image_latent = ( - image_latent.reshape(b * t, -1, self.embed_dim) - if image_latent is not None and image_latent.numel() > 0 - else None - ) - - if num_frames == 1 and image_latent is None and not get_sequence_parallel_state(): - image_latent = video_latent - video_latent = None - - return video_latent, image_latent - - -class OverlapPatchEmbed3D(nn.Cell): - """2D Image to Patch Embedding but with 3D positional embedding""" - - def __init__( - self, - num_frames=1, - height=224, - width=224, - patch_size_t=1, - patch_size=16, - in_channels=3, - embed_dim=768, - layer_norm=False, - flatten=True, - bias=True, - interpolation_scale=(1, 1), - interpolation_scale_t=1, - use_abs_pos=True, - ): - super().__init__() - # assert num_frames == 1 - self.use_abs_pos = use_abs_pos - self.flatten = flatten - self.layer_norm = layer_norm - self.embed_dim = embed_dim - - self.proj = nn.Conv3d( - in_channels, - embed_dim, - kernel_size=(patch_size_t, patch_size, patch_size), - stride=(patch_size_t, patch_size, patch_size), - has_bias=bias, - ) - if layer_norm: - self.norm = LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) - else: - self.norm = None - self.patch_size_t = patch_size_t - self.patch_size = patch_size - # See: - # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161 - self.height, self.width = height // patch_size, width // patch_size - self.base_size = (height // patch_size, width // patch_size) - self.interpolation_scale = (interpolation_scale[0], interpolation_scale[1]) - pos_embed = get_2d_sincos_pos_embed( - embed_dim, (self.height, self.width), base_size=self.base_size, interpolation_scale=self.interpolation_scale - ) - self.pos_embed = ms.Tensor(pos_embed).float().unsqueeze(0) - self.num_frames = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t - self.base_size_t = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t - self.interpolation_scale_t = interpolation_scale_t - - if get_sequence_parallel_state(): - self.sp_size = hccl_info.world_size - rank_offset = hccl_info.rank % hccl_info.world_size - num_frames = (self.num_frames + self.sp_size - 1) // self.sp_size * self.sp_size - temp_pos_embed = get_1d_sincos_pos_embed( - embed_dim, num_frames, base_size=self.base_size_t, interpolation_scale=self.interpolation_scale_t - ) - num_frames //= self.sp_size - self.temp_pos_st = rank_offset * num_frames - self.temp_pos_ed = (rank_offset + 1) * num_frames - else: - temp_pos_embed = get_1d_sincos_pos_embed( - embed_dim, self.num_frames, base_size=self.base_size_t, interpolation_scale=self.interpolation_scale_t - ) - - self.temp_pos_embed = ms.Tensor(temp_pos_embed).float().unsqueeze(0) - - def construct(self, latent, num_frames): - b, c, t, h, w = latent.shape - video_latent, image_latent = None, None - height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size - - if npu_config is not None and npu_config.on_npu: - latent = npu_config.run_conv3d(self.proj, latent, latent.dtype) - else: - latent = self.proj(latent) - - if self.flatten: - # b c t h w -> (b t) (h w) c - latent = latent.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c) - if self.layer_norm: - latent = self.norm(latent) - - if self.use_abs_pos: - # Interpolate positional embeddings if needed. - # (For PixArt-Alpha: https://github.com/PixArt-alpha/\ - # PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160) - if self.height != height or self.width != width: - # raise NotImplementedError - pos_embed = get_2d_sincos_pos_embed( - embed_dim=self.pos_embed.shape[-1], - grid_size=(height, width), - base_size=self.base_size, - interpolation_scale=self.interpolation_scale, - ) - pos_embed = ms.Tensor(pos_embed) - pos_embed = pos_embed.float().unsqueeze(0) - else: - pos_embed = self.pos_embed - - if self.num_frames != num_frames: - if get_sequence_parallel_state(): - # f, h -> f, 1, h - temp_pos_embed = self.temp_pos_embed[self.temp_pos_st : self.temp_pos_ed].unsqueeze(1) - else: - temp_pos_embed = get_1d_sincos_pos_embed( - embed_dim=self.temp_pos_embed.shape[-1], - grid_size=num_frames, - base_size=self.base_size_t, - interpolation_scale=self.interpolation_scale_t, - ) - temp_pos_embed = ms.Tensor(temp_pos_embed) - temp_pos_embed = temp_pos_embed.float().unsqueeze(0) - else: - temp_pos_embed = self.temp_pos_embed - - latent = (latent + pos_embed).to(latent.dtype) - - # (b t) n c -> b t n c - latent = latent.reshape(b, t, -1, self.embed_dim) - video_latent, image_latent = latent[:, :num_frames], latent[:, num_frames:] - - if self.use_abs_pos: - # temp_pos_embed = temp_pos_embed.unsqueeze(2) * self.temp_embed_gate.tanh() - temp_pos_embed = temp_pos_embed.unsqueeze(2) - video_latent = ( - (video_latent + temp_pos_embed).to(video_latent.dtype) - if video_latent is not None and video_latent.numel() > 0 - else None - ) - image_latent = ( - (image_latent + temp_pos_embed[:, :1]).to(image_latent.dtype) - if image_latent is not None and image_latent.numel() > 0 - else None - ) - # 'b t n c -> b (t n) c' - video_latent = ( - video_latent.reshape(b, -1, self.embed_dim) - if video_latent is not None and video_latent.numel() > 0 - else None - ) - # 'b t n c -> (b t) n c' - image_latent = ( - image_latent.reshape(b * t, -1, self.embed_dim) - if image_latent is not None and image_latent.numel() > 0 - else None - ) - - if num_frames == 1 and image_latent is None: - image_latent = video_latent - video_latent = None - - return video_latent, image_latent - - -class OverlapPatchEmbed2D(nn.Cell): - """2D Image to Patch Embedding but with 3D positional embedding""" - - def __init__( - self, - num_frames=1, - height=224, - width=224, - patch_size_t=1, - patch_size=16, - in_channels=3, - embed_dim=768, - layer_norm=False, - flatten=True, - bias=True, - interpolation_scale=(1, 1), - interpolation_scale_t=1, - use_abs_pos=True, - ): - super().__init__() - assert patch_size_t == 1 - self.use_abs_pos = use_abs_pos - self.flatten = flatten - self.layer_norm = layer_norm - self.embed_dim = embed_dim - - self.proj = nn.Conv2d( - in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), has_bias=bias - ) - if layer_norm: - self.norm = LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) - else: - self.norm = None - self.patch_size_t = patch_size_t - self.patch_size = patch_size - # See: - # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161 - self.height, self.width = height // patch_size, width // patch_size - self.base_size = (height // patch_size, width // patch_size) - self.interpolation_scale = (interpolation_scale[0], interpolation_scale[1]) - pos_embed = get_2d_sincos_pos_embed( - embed_dim, (self.height, self.width), base_size=self.base_size, interpolation_scale=self.interpolation_scale - ) - self.pos_embed = ms.Tensor(pos_embed).float().unsqueeze(0) - self.num_frames = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t - self.base_size_t = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t - self.interpolation_scale_t = interpolation_scale_t - - if get_sequence_parallel_state(): - self.sp_size = hccl_info.world_size - rank_offset = hccl_info.rank % hccl_info.world_size - num_frames = (self.num_frames + self.sp_size - 1) // self.sp_size * self.sp_size - temp_pos_embed = get_1d_sincos_pos_embed( - embed_dim, num_frames, base_size=self.base_size_t, interpolation_scale=self.interpolation_scale_t - ) - num_frames //= self.sp_size - self.temp_pos_st = rank_offset * num_frames - self.temp_pos_ed = (rank_offset + 1) * num_frames - else: - temp_pos_embed = get_1d_sincos_pos_embed( - embed_dim, self.num_frames, base_size=self.base_size_t, interpolation_scale=self.interpolation_scale_t - ) - - self.temp_pos_embed = ms.Tensor(temp_pos_embed).float().unsqueeze(0) - - def construct(self, latent, num_frames): - b, c, t, h, w = latent.shape - video_latent, image_latent = None, None - height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size - # b c t h w -> (bt) c h w - latent = latent.swapaxes(1, 2).reshape(b * t, c, h, w) - latent = self.proj(latent) - if self.flatten: - latent = latent.flatten(start_dim=2).permute(0, 2, 1) # BT C H W -> BT N C - if self.layer_norm: - latent = self.norm(latent) - - if self.use_abs_pos: - # Interpolate positional embeddings if needed. - # (For PixArt-Alpha: https://github.com/PixArt-alpha/\ - # PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160) - if self.height != height or self.width != width: - # raise NotImplementedError - pos_embed = get_2d_sincos_pos_embed( - embed_dim=self.pos_embed.shape[-1], - grid_size=(height, width), - base_size=self.base_size, - interpolation_scale=self.interpolation_scale, - ) - pos_embed = ms.Tensor(pos_embed) - pos_embed = pos_embed.float().unsqueeze(0) - else: - pos_embed = self.pos_embed - - if self.num_frames != num_frames: - if get_sequence_parallel_state(): - # f, h -> f, 1, h - temp_pos_embed = self.temp_pos_embed[self.temp_pos_st : self.temp_pos_ed].unsqueeze(1) - else: - temp_pos_embed = get_1d_sincos_pos_embed( - embed_dim=self.temp_pos_embed.shape[-1], - grid_size=num_frames, - base_size=self.base_size_t, - interpolation_scale=self.interpolation_scale_t, - ) - temp_pos_embed = ms.Tensor(temp_pos_embed) - temp_pos_embed = temp_pos_embed.float().unsqueeze(0) - else: - temp_pos_embed = self.temp_pos_embed - - latent = (latent + pos_embed).to(latent.dtype) - - # (b t) n c -> b t n c - latent = latent.reshape(b, t, -1, self.embed_dim) - video_latent, image_latent = latent[:, :num_frames], latent[:, num_frames:] - - if self.use_abs_pos: - # temp_pos_embed = temp_pos_embed.unsqueeze(2) * self.temp_embed_gate.tanh() - temp_pos_embed = temp_pos_embed.unsqueeze(2) - video_latent = ( - (video_latent + temp_pos_embed).to(video_latent.dtype) - if video_latent is not None and video_latent.numel() > 0 - else None - ) - image_latent = ( - (image_latent + temp_pos_embed[:, :1]).to(image_latent.dtype) - if image_latent is not None and image_latent.numel() > 0 - else None - ) - # 'b t n c -> b (t n) c' - video_latent = ( - video_latent.reshape(b, -1, self.embed_dim) - if video_latent is not None and video_latent.numel() > 0 - else None - ) - # 'b t n c -> (b t) n c' - image_latent = ( - image_latent.reshape(b * t, -1, self.embed_dim) - if image_latent is not None and image_latent.numel() > 0 - else None - ) - - if num_frames == 1 and image_latent is None: - image_latent = video_latent - video_latent = None - - return video_latent, image_latent - - -class DownSampler3d(nn.Cell): - def __init__(self, *args, **kwargs): - """Required kwargs: down_factor, downsampler""" - super().__init__() - self.down_factor = kwargs.pop("down_factor") - self.down_shortcut = kwargs.pop("down_shortcut") - self.layer = nn.Conv3d(*args, **kwargs) - - def construct(self, x, attention_mask, t, h, w): - b = x.shape[0] - # b (t h w) d -> b d t h w - x = x.reshape(b, t, h, w, -1).permute(0, 4, 1, 2, 3) - - x_dtype = x.dtype - if npu_config is not None and npu_config.on_npu: - conv_out = npu_config.run_conv3d(self.layer, x, x_dtype) - else: - conv_out = self.layer(x) - x = conv_out + (x if self.down_shortcut else 0) - - # b d (t dt) (h dh) (w dw) -> (b dt dh dw) (t h w) d - dt, dh, dw = self.down_factor - x = x.reshape(b, -1, t // dt, dt, h // dh, dh, w // dw, dw) - x = x.permute(0, 3, 5, 7, 2, 4, 6, 1).reshape(b * dt * dw * dh, -1, x.shape[1]) - # b 1 (t h w) -> b 1 t h w - attention_mask = attention_mask.reshape(b, 1, t, h, w) - # b 1 (t dt) (h dh) (w dw) -> (b dt dh dw) 1 (t h w) - attention_mask = attention_mask.reshape(b, 1, t // dt, dt, h // dh, dh, w // dw, dw) - attention_mask = attention_mask.permute(0, 3, 5, 7, 1, 2, 4, 6).reshape(b * dt * dh * dw, 1, -1) - - return x, attention_mask - - def reverse(self, x, t, h, w): - d = x.shape[2] - dt, dh, dw = self.down_factor - # (b dt dh dw) (t h w) d -> b (t dt h dh w dw) d - x = x.reshape(-1, dt, dh, dw, t, h, w, d) - x = x.permute(0, 4, 1, 5, 2, 6, 3, 7).reshape(-1, t * dt * h * dt * w * dw, d) - return x - - -class DownSampler2d(nn.Cell): - def __init__(self, *args, **kwargs): - """Required kwargs: down_factor, downsampler""" - super().__init__() - self.down_factor = kwargs.pop("down_factor") - self.down_shortcut = kwargs.pop("down_shortcut") - self.layer = nn.Conv2d(*args, **kwargs) - - def construct(self, x, attention_mask, t, h, w): - b = x.shape[0] - d = x.shape[-1] - # b (t h w) d -> (b t) d h w - x = x.reshape(b, t, h, w, -1).permute(0, 1, 4, 2, 3).reshape(b * t, d, h, w) - x = self.layer(x) + (x if self.down_shortcut else 0) - - dh, dw = self.down_factor - # b d (h dh) (w dw) -> (b dh dw) (h w) d - x = x.reshape(b, d, h // dh, dh, w // dw, dw) - x = x.permute(0, 3, 5, 2, 4, 1).reshape(b * dh * dw, -1, d) - # b 1 (t h w) -> (b t) 1 h w - attention_mask = attention_mask.reshape(b, 1, t, h, w).swapaxes(1, 2).reshape(b * t, 1, h, w) - # b 1 (h dh) (w dw) -> (b dh dw) 1 (h w) - attention_mask = attention_mask.reshape(b, 1, h // dh, dh, w // dw, dw) - attention_mask = attention_mask.permute(0, 3, 5, 1, 2, 4).reshape(b * dh * dw, 1, -1) - - return x, attention_mask - - def reverse(self, x, t, h, w): - # (b t dh dw) (h w) d -> b (t h dh w dw) d - d = x.shape[-1] - dh, dw = self.down_factor - x = x.reshape(-1, t, dh, dw, h, w, d) - x = x.permute(0, 1, 4, 2, 5, 3, 6).reshape(-1, t * h * dh * w * dw, d) - return x - - -class FeedForward_Conv2d(nn.Cell): - def __init__(self, downsampler, dim, hidden_features, bias=True): - super(FeedForward_Conv2d, self).__init__() - - self.bias = bias - - self.project_in = nn.Dense(dim, hidden_features, has_bias=bias) - - self.dwconv = nn.CellList( - [ - nn.Conv2d( - hidden_features, - hidden_features, - kernel_size=(5, 5), - stride=1, - padding=(2, 2), - dilation=1, - groups=hidden_features, - has_bias=bias, - pad_mode="pad", - ), - nn.Conv2d( - hidden_features, - hidden_features, - kernel_size=(3, 3), - stride=1, - padding=(1, 1), - dilation=1, - groups=hidden_features, - has_bias=bias, - pad_mode="pad", - ), - nn.Conv2d( - hidden_features, - hidden_features, - kernel_size=(1, 1), - stride=1, - padding=(0, 0), - dilation=1, - groups=hidden_features, - has_bias=bias, - pad_mode="pad", - ), - ] - ) - - self.project_out = nn.Dense(hidden_features, dim, has_bias=bias) - self.gelu = nn.GELU(approximate=False) - - def construct(self, x, t, h, w): - x = self.project_in(x) - b, _, d = x.shape - # b (t h w) d -> (b t) d h w - x = x.reshape(b, t, h, w, d).permute(0, 1, 4, 2, 3).reshape(b * t, d, h, w) - x = self.gelu(x) - out = x - for module in self.dwconv: - out = out + module(x) - # (b t) d h w -> b (t h w) d - d = out.shape[1] - out = out.reshape(b, t, d, h, w).permute(0, 1, 3, 4, 2).reshape(b, -1, d) - x = self.project_out(out) - return x - - class BasicTransformerBlock(nn.Cell): - r""" - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. - only_cross_attention (`bool`, *optional*): - Whether to use only cross-attention layers. In this case two cross attention layers are used. - double_self_attention (`bool`, *optional*): - Whether to use two self-attention layers. In this case no cross attention layers are used. - upcast_attention (`bool`, *optional*): - Whether to upcast the attention computation to float32. This is useful for mixed precision training. - norm_elementwise_affine (`bool`, *optional*, defaults to `True`): - Whether to use learnable elementwise affine parameters for normalization. - norm_type (`str`, *optional*, defaults to `"layer_norm"`): - The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. - final_dropout (`bool` *optional*, defaults to False): - Whether to apply a final dropout after the last feed-forward layer. - attention_type (`str`, *optional*, defaults to `"default"`): - The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. - positional_embeddings (`str`, *optional*, defaults to `None`): - The type of positional embeddings to apply to. - num_positional_embeddings (`int`, *optional*, defaults to `None`): - The maximum number of positional embeddings to apply. - """ - def __init__( self, dim: int, @@ -1184,77 +538,29 @@ def __init__( dropout=0.0, cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, double_self_attention: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' norm_eps: float = 1e-5, final_dropout: bool = False, - attention_type: str = "default", - positional_embeddings: Optional[str] = None, - num_positional_embeddings: Optional[int] = None, - ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, - ada_norm_bias: Optional[int] = None, ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, - downsampler: str = None, - interpolation_scale_thw: Tuple[int] = (1, 1, 1), + interpolation_scale_thw: Tuple[int] = (1, 1, 1), + sparse1d: bool = False, + sparse_n: int = 2, + sparse_group: bool = False, attention_mode: str = "xformers", - use_rope: bool = False, FA_dtype=ms.bfloat16, ): super().__init__() - self.only_cross_attention = only_cross_attention - self.downsampler = downsampler - - # We keep these boolean flags for backward-compatibility. - self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" - self.use_ada_layer_norm_single = norm_type == "ada_norm_single" - self.use_layer_norm = norm_type == "layer_norm" - self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" self.FA_dtype = FA_dtype - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - self.norm_type = norm_type - self.num_embeds_ada_norm = num_embeds_ada_norm - if positional_embeddings and (num_positional_embeddings is None): - raise ValueError( - "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." - ) - - if positional_embeddings == "sinusoidal": - self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) - else: - self.pos_embed = None - # Define 3 blocks. Each block has its own normalization layer. # 1. Self-Attn - if norm_type == "ada_norm": - self.norm1_ada = AdaLayerNorm(dim, num_embeds_ada_norm) - self.norm1_ada.norm = LayerNorm(dim, elementwise_affine=False) - elif norm_type == "ada_norm_zero": - self.norm1_ada_zero = AdaLayerNormZero(dim, num_embeds_ada_norm) - self.norm1_ada_zero.norm = LayerNorm(dim, elementwise_affine=False) - elif norm_type == "ada_norm_continuous": - self.norm1_ada_con = AdaLayerNormContinuous( - dim, - ada_norm_continous_conditioning_embedding_dim, - norm_elementwise_affine, - norm_eps, - ada_norm_bias, - "rms_norm", - ) - else: - self.norm1_ln = LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.norm1 = LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) self.attn1 = Attention( query_dim=dim, @@ -1264,100 +570,50 @@ def __init__( bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, - attention_mode=attention_mode, out_bias=attention_out_bias, - downsampler=downsampler, - use_rope=use_rope, - interpolation_scale_thw=interpolation_scale_thw, + interpolation_scale_thw=interpolation_scale_thw, + sparse1d=sparse1d, + sparse_n=sparse_n, + sparse_group=sparse_group, + is_cross_attn=False, + attention_mode=attention_mode, FA_dtype=self.FA_dtype, ) # 2. Cross-Attn - if cross_attention_dim is not None or double_self_attention: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - if norm_type == "ada_norm": - self.norm2_ada = AdaLayerNorm(dim, num_embeds_ada_norm) - self.norm2_ada.norm = LayerNorm(dim, elementwise_affine=False) - elif norm_type == "ada_norm_continuous": - self.norm2_ada_con = AdaLayerNormContinuous( - dim, - ada_norm_continous_conditioning_embedding_dim, - norm_elementwise_affine, - norm_eps, - ada_norm_bias, - "rms_norm", - ) - else: - self.norm2_ln = LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - out_bias=attention_out_bias, - downsampler=False, - use_rope=False, # do not position in cross attention - attention_mode=attention_mode, - FA_dtype=self.FA_dtype, - interpolation_scale_thw=interpolation_scale_thw, - ) # is self-attn if encoder_hidden_states is none - else: - self.norm2 = None - self.attn2 = None + self.norm2 = LayerNorm(dim, norm_eps, norm_elementwise_affine) - # 3. Feed-forward - if norm_type == "ada_norm_continuous": - self.norm3 = AdaLayerNormContinuous( - dim, - ada_norm_continous_conditioning_embedding_dim, - norm_elementwise_affine, - norm_eps, - ada_norm_bias, - "layer_norm", - ) - elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm", "ada_norm_continuous"]: - self.norm3 = LayerNorm(dim, norm_eps, norm_elementwise_affine) - elif norm_type == "layer_norm_i2vgen": - self.norm3 = None - if downsampler: - self.ff = FeedForward_Conv2d( - downsampler, - dim, - 2 * dim, - bias=ff_bias, - ) - else: - self.ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - final_dropout=final_dropout, - inner_dim=ff_inner_dim, - bias=ff_bias, - ) - - # 4. Fuser - if attention_type == "gated" or attention_type == "gated-text-image": - self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + interpolation_scale_thw=interpolation_scale_thw, + sparse1d=sparse1d, + sparse_n=sparse_n, + sparse_group=sparse_group, + is_cross_attn=True, + attention_mode=attention_mode, + FA_dtype=self.FA_dtype, + ) # is self-attn if encoder_hidden_states is none - # 5. Scale-shift for PixArt-Alpha. - if self.use_ada_layer_norm_single: - self.scale_shift_table = ms.Parameter(ops.randn(6, dim) / dim**0.5) + # 3. Feed-forward + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 + # 4. Scale-shift. + self.scale_shift_table = Parameter(ops.randn((6, dim)) / dim**0.5) - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim def construct( self, @@ -1366,130 +622,58 @@ def construct( encoder_hidden_states: Optional[ms.Tensor] = None, encoder_attention_mask: Optional[ms.Tensor] = None, timestep: Optional[ms.Tensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[ms.Tensor] = None, - frame: int = None, - height: int = None, - width: int = None, - added_cond_kwargs: Optional[Dict[str, ms.Tensor]] = None, + frame: int = None, + height: int = None, + width: int = None, ) -> ms.Tensor: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - - # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention - batch_size = hidden_states.shape[0] - gate_msa, shift_mlp, scale_mlp, gate_mlp = None, None, None, None - if self.norm_type == "ada_norm": - norm_hidden_states = self.norm1_ada(hidden_states, timestep) - elif self.norm_type == "ada_norm_zero": - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1_ada_zero( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: - norm_hidden_states = self.norm1_ln(hidden_states) - elif self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm1_ada_con(hidden_states, added_cond_kwargs["pooled_text_emb"]) - - elif self.norm_type == "ada_norm_single": - if get_sequence_parallel_state(): - batch_size = hidden_states.shape[1] # S B H - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk( - self.scale_shift_table[:, None] + timestep.reshape(6, batch_size, -1), 6, dim=0 - ) - else: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk( - self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1), 6, dim=1 - ) - norm_hidden_states = self.norm1_ln(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - # norm_hidden_states = norm_hidden_states.squeeze(1) # error message + if get_sequence_parallel_state(): + batch_size = hidden_states.shape[1] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk( + self.scale_shift_table[:, None] + timestep.reshape(6, batch_size, -1), 6, dim=0) else: - raise ValueError("Incorrect norm used") + batch_size = hidden_states.shape[0] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1), 6, dim=1 + ) - if self.pos_embed is not None: - norm_hidden_states = self.pos_embed(norm_hidden_states) + norm_hidden_states = self.norm1(hidden_states) + + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - # 1. Prepare GLIGEN inputs - cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} - if "gligen" in cross_attention_kwargs: - gligen_kwargs = cross_attention_kwargs["gligen"] - else: - gligen_kwargs = None attn_output = self.attn1( norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - frame=frame, - height=height, - width=width, - **cross_attention_kwargs, + encoder_hidden_states=None, + attention_mask=attention_mask, frame=frame, height=height, width=width, ) - if self.use_ada_layer_norm_zero: - attn_output = gate_msa.unsqueeze(1) * attn_output - elif self.use_ada_layer_norm_single: - attn_output = gate_msa * attn_output + + attn_output = gate_msa * attn_output hidden_states = attn_output + hidden_states if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) - # 1.2 GLIGEN Control - if gligen_kwargs is not None: - hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) - # 3. Cross-Attention - if self.attn2 is not None: - if self.norm_type == "ada_norm": - norm_hidden_states = self.norm2_ada(hidden_states, timestep) - elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: - norm_hidden_states = self.norm2_ln(hidden_states) - elif self.norm_type == "ada_norm_single": - # For PixArt norm2 isn't applied here: - # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 - norm_hidden_states = hidden_states - elif self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm2_ada_con(hidden_states, added_cond_kwargs["pooled_text_emb"]) - else: - raise ValueError("Incorrect norm") + norm_hidden_states = hidden_states - if self.pos_embed is not None and self.norm_type != "ada_norm_single": - norm_hidden_states = self.pos_embed(norm_hidden_states) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, frame=frame, height=height, width=width, + ) - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states + hidden_states = attn_output + hidden_states # 4. Feed-forward - if self.norm_type == "ada_norm_continuous": - norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) - elif not self.norm_type == "ada_norm_single": - norm_hidden_states = self.norm3(hidden_states) - - if self.norm_type == "ada_norm_zero": - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_hidden_states = self.norm2(hidden_states) - if self.norm_type == "ada_norm_single": - norm_hidden_states = self.norm2_ln(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp - if self.downsampler: - ff_output = self.ff(norm_hidden_states, t=frame, h=height, w=width) - else: - ff_output = self.ff(norm_hidden_states) + ff_output = self.ff(norm_hidden_states) - if self.norm_type == "ada_norm_zero": - ff_output = gate_mlp.unsqueeze(1) * ff_output - elif self.norm_type == "ada_norm_single": - ff_output = gate_mlp * ff_output + ff_output = gate_mlp * ff_output hidden_states = ff_output + hidden_states - if hidden_states.ndim == 4: - hidden_states = hidden_states.squeeze(1) - return hidden_states + return hidden_states \ No newline at end of file diff --git a/examples/opensora_pku/opensora/sample/pipeline_opensora.py b/examples/opensora_pku/opensora/sample/pipeline_opensora.py index 8a4c99e5cf..a43eddec01 100644 --- a/examples/opensora_pku/opensora/sample/pipeline_opensora.py +++ b/examples/opensora_pku/opensora/sample/pipeline_opensora.py @@ -4,7 +4,7 @@ import math import re import urllib.parse as ul -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union, Dict from opensora.acceleration.communications import AllGather from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info @@ -13,7 +13,13 @@ from mindspore import mint, ops from mindone.diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from mindone.diffusers.utils import BACKENDS_MAPPING, deprecate, is_bs4_available, is_ftfy_available +from mindone.diffusers.utils import BACKENDS_MAPPING, deprecate, is_bs4_available, is_ftfy_available, BaseOutput +from mindone.diffusers import AutoencoderKL +from mindone.diffusers import DDPMScheduler, FlowMatchEulerDiscreteScheduler +# from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback #TODO:TBD + +from mindone.transformers import CLIPTextModelWithProjection, T5EncoderModel +from transformers import CLIPTokenizer, CLIPImageProcessor, MT5Tokenizer logger = logging.getLogger(__name__) @@ -22,6 +28,31 @@ if is_ftfy_available(): import ftfy + +from dataclasses import dataclass +import numpy as np +import PIL + +from examples.opensora_pku.opensora.models.diffusion.opensora.modeling_opensora import OpenSoraT2V_v1_3 + + +@dataclass +class OpenSoraPipelineOutput(BaseOutput): + videos: Union[List[ms.Tensor], List[PIL.Image.Image], np.ndarray] + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(axis=list(range(1, noise_pred_text.ndim)), ddof=True, keepdims=True) + std_cfg = noise_cfg.std(axis=list(range(1, noise_cfg.ndim)), ddof=True, keepdims=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps @@ -29,6 +60,7 @@ def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -42,14 +74,18 @@ def retrieve_timesteps( The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + `Tuple[ms.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -60,6 +96,16 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, **kwargs) timesteps = scheduler.timesteps @@ -67,72 +113,63 @@ def retrieve_timesteps( class OpenSoraPipeline(DiffusionPipeline): - r""" - Pipeline for text-to-image generation using PixArt-Sigma. - """ - bad_punct_regex = re.compile( - r"[" - + "#®•©™&@·º½¾¿¡§~" - + r"\)" - + r"\(" - + r"\]" - + r"\[" - + r"\}" - + r"\{" - + r"\|" - + r"\\" - + r"\/" - + r"\*" - + r"]{1,}" - ) # noqa - - _optional_components = ["tokenizer", "text_encoder"] - model_cpu_offload_seq = "text_encoder->transformer->vae" + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [ + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] def __init__( self, - transformer, - tokenizer, - text_encoder, - vae, - scheduler, + vae: AutoencoderKL, + text_encoder: T5EncoderModel, + tokenizer: MT5Tokenizer, + transformer: OpenSoraT2V_v1_3, + scheduler: DDPMScheduler, + text_encoder_2: CLIPTextModelWithProjection = None, + tokenizer_2: CLIPTokenizer = None, ): super().__init__() self.register_modules( - tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, ) - # self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.all_gather = None if not get_sequence_parallel_state() else AllGather() - # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py - def mask_text_embeddings(self, emb, mask): - if emb.shape[0] == 1: - keep_index = int(mask.sum()) - return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096 - else: - masked_feature = emb * mask[:, None, :, None] # 1 120 4096 - return masked_feature, emb.shape[2] @ms.jit # FIXME: on ms2.3, in pynative mode, text encoder's output has nan problem. - def text_encoding_func(self, input_ids, attention_mask): - return ops.stop_gradient(self.text_encoder(input_ids, attention_mask=attention_mask)) + def text_encoding_func(self, text_encoder, input_ids, attention_mask): + return ops.stop_gradient(text_encoder(input_ids, attention_mask=attention_mask)) - # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt def encode_prompt( self, - prompt: Union[str, List[str]], + prompt: str, + dtype = None, + num_samples_per_prompt: int = 1, do_classifier_free_guidance: bool = True, - negative_prompt: str = "", - num_images_per_prompt: int = 1, + negative_prompt: Optional[str] = None, prompt_embeds: Optional[ms.Tensor] = None, negative_prompt_embeds: Optional[ms.Tensor] = None, prompt_attention_mask: Optional[ms.Tensor] = None, negative_prompt_attention_mask: Optional[ms.Tensor] = None, - clean_caption: bool = False, - max_sequence_length: int = 120, - **kwargs, + max_sequence_length: Optional[int] = None, + text_encoder_index: int = 0, ): r""" Encodes the prompt into text encoder hidden states. @@ -140,112 +177,140 @@ def encode_prompt( Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` - instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For - PixArt-Alpha, this should be "". - do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): - whether to use classifier free guidance or not - num_images_per_prompt (`int`, *optional*, defaults to 1): + dtype (`ms.dtype`): + mindspore dtype + num_samples_per_prompt (`int`): number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). prompt_embeds (`ms.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`ms.Tensor`, *optional*): - Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" - string. - clean_caption (`bool`, defaults to `False`): - If `True`, the function will preprocess and clean the provided caption before encoding. - max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`ms.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`ms.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for T5 and `1` for clip. """ - if "mask_feature" in kwargs: - deprecation_message = ( - "The use of `mask_feature` is deprecated. It is no longer used in any computation and " - "that doesn't affect the end results. It will be removed in a future version." - ) - deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + if dtype is None: + if self.text_encoder_2 is not None: + dtype = self.text_encoder_2.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] + + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = 512 + if text_encoder_index == 1: + max_length = 77 + else: + max_length = max_sequence_length if prompt is not None and isinstance(prompt, str): batch_size = 1 + prompt = [prompt] elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] - # See Section 3.1. of the paper. - max_length = max_sequence_length - if prompt_embeds is None: - prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) - text_inputs = self.tokenizer( + text_inputs = tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, - add_special_tokens=True, + # return_attention_mask=True, return_tensors=None, ) text_input_ids = ms.Tensor(text_inputs.input_ids) - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors=None).input_ids - untruncated_ids = ms.Tensor(untruncated_ids) + untruncated_ids = ms.Tensor(tokenizer(prompt, padding="longest", return_tensors=None).input_ids) if ( - untruncated_ids.shape[-1] >= text_input_ids.shape[-1] - and not ops.equal(text_input_ids, untruncated_ids[:, : text_input_ids.shape[-1]]).all() + untruncated_ids.shape[-1] >= text_input_ids.shape[-1] + and not ops.equal(text_input_ids, untruncated_ids[:, :text_input_ids.shape[-1]]).all() ): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( - "The following part of your input was truncated because the model can only handle sequences up to" - f" {max_length} tokens: {removed_text}" + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_attention_mask = ms.Tensor(text_inputs.attention_mask) - - prompt_embeds = self.text_encoding_func(text_input_ids, attention_mask=prompt_attention_mask) + prompt_embeds = self.text_encoding_func(text_encoder, text_input_ids, attention_mask=prompt_attention_mask) prompt_embeds = prompt_embeds[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds - else: - prompt_attention_mask = ops.ones_like(prompt_embeds) - if self.text_encoder is not None: - dtype = self.text_encoder.dtype - elif self.transformer is not None: - dtype = self.transformer.dtype + if text_encoder_index == 1: + prompt_embeds = prompt_embeds.unsqueeze(1) # b d -> b 1 d for clip + + prompt_attention_mask = prompt_attention_mask.repeat(num_samples_per_prompt, axis=0) else: - dtype = None + prompt_attention_mask = ops.ones_like(prompt_embeds) prompt_embeds = prompt_embeds.to(dtype=dtype) + bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) - prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 0) + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(num_samples_per_prompt, axis=1) + prompt_embeds = prompt_embeds.view((bs_embed * num_samples_per_prompt, seq_len, -1)) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens = [negative_prompt] * batch_size - uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + # elif prompt is not None and type(prompt) is not type(negative_prompt): + # raise TypeError( + # f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + # f" {type(prompt)}." + # ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, - return_attention_mask=True, - add_special_tokens=True, return_tensors=None, ) + + uncond_text_inputs = ms.Tensor(uncond_input.input_ids) negative_prompt_attention_mask = ms.Tensor(uncond_input.attention_mask) - negative_prompt_embeds = self.text_encoding_func( - ms.Tensor(uncond_input.input_ids), - attention_mask=negative_prompt_attention_mask, - ) - negative_prompt_embeds = ( - negative_prompt_embeds[0] - if isinstance(negative_prompt_embeds, (list, tuple)) - else negative_prompt_embeds - ) + negative_prompt_embeds = self.text_encoding_func(text_encoder, uncond_text_inputs, attention_mask=negative_prompt_attention_mask) + negative_prompt_embeds = negative_prompt_embeds[0] if isinstance(negative_prompt_embeds, (list, tuple)) else negative_prompt_embeds + + if text_encoder_index == 1: + negative_prompt_embeds = negative_prompt_embeds.unsqueeze(1) # b d -> b 1 d for clip + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_samples_per_prompt, axis=0) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method @@ -253,16 +318,13 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype) - negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) - negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 0) + negative_prompt_embeds = negative_prompt_embeds.repeat(num_samples_per_prompt, axis = 1) + negative_prompt_embeds = negative_prompt_embeds.view((batch_size * num_samples_per_prompt, seq_len, -1)) else: negative_prompt_embeds = None negative_prompt_attention_mask = None - return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -288,24 +350,27 @@ def check_inputs( num_frames, height, width, - negative_prompt, - callback_steps, + negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, prompt_attention_mask=None, negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, ): - if num_frames <= 0: - raise ValueError(f"`num_frames` have to be positive but is {num_frames}.") + if (num_frames - 1) % 4 != 0: + raise ValueError(f"`num_frames - 1` have to be divisible by 4 but is {num_frames}.") if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: @@ -317,19 +382,18 @@ def check_inputs( raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) if prompt_embeds is not None and prompt_attention_mask is None: raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") - if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: - raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -337,6 +401,13 @@ def check_inputs( f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( @@ -344,181 +415,37 @@ def check_inputs( f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) - if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: raise ValueError( - "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" - f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" - f" {negative_prompt_attention_mask.shape}." + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." ) - # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing - def _text_preprocessing(self, text, clean_caption=False): - if clean_caption and not is_bs4_available(): - logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) - logger.warn("Setting `clean_caption` to False...") - clean_caption = False - - if clean_caption and not is_ftfy_available(): - logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) - logger.warn("Setting `clean_caption` to False...") - clean_caption = False - - if not isinstance(text, (tuple, list)): - text = [text] - - def process(text: str): - if clean_caption: - text = self._clean_caption(text) - text = self._clean_caption(text) - else: - text = text.lower().strip() - return text - - return [process(t) for t in text] - - # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption - def _clean_caption(self, caption): - caption = str(caption) - caption = ul.unquote_plus(caption) - caption = caption.strip().lower() - caption = re.sub("", "person", caption) - # urls: - caption = re.sub( - r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", - # noqa - "", - caption, - ) # regex for urls - caption = re.sub( - r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", - # noqa - "", - caption, - ) # regex for urls - # html: - caption = BeautifulSoup(caption, features="html.parser").text - - # @ - caption = re.sub(r"@[\w\d]+\b", "", caption) - - # 31C0—31EF CJK Strokes - # 31F0—31FF Katakana Phonetic Extensions - # 3200—32FF Enclosed CJK Letters and Months - # 3300—33FF CJK Compatibility - # 3400—4DBF CJK Unified Ideographs Extension A - # 4DC0—4DFF Yijing Hexagram Symbols - # 4E00—9FFF CJK Unified Ideographs - caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) - caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) - caption = re.sub(r"[\u3200-\u32ff]+", "", caption) - caption = re.sub(r"[\u3300-\u33ff]+", "", caption) - caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) - caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) - # caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) - ####################################################### - - # все виды тире / all types of dash --> "-" - caption = re.sub( - r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", - # noqa - "-", - caption, - ) - - # кавычки к одному стандарту - caption = re.sub(r"[`´«»“”¨]", '"', caption) - caption = re.sub(r"[‘’]", "'", caption) - - # " - caption = re.sub(r""?", "", caption) - # & - caption = re.sub(r"&", "", caption) - - # ip adresses: - caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) - - # article ids: - caption = re.sub(r"\d:\d\d\s+$", "", caption) - - # \n - caption = re.sub(r"\\n", " ", caption) - - # "#123" - caption = re.sub(r"#\d{1,3}\b", "", caption) - # "#12345.." - caption = re.sub(r"#\d{5,}\b", "", caption) - # "123456.." - caption = re.sub(r"\b\d{6,}\b", "", caption) - # filenames: - caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) - - # - caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" - caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" - - caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT - caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " - - # this-is-my-cute-cat / this_is_my_cute_cat - regex2 = re.compile(r"(?:\-|\_)") - if len(re.findall(regex2, caption)) > 3: - caption = re.sub(regex2, " ", caption) - - caption = ftfy.fix_text(caption) - caption = html.unescape(html.unescape(caption)) - - caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 - caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc - caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 - - caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) - caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) - caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) - caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) - caption = re.sub(r"\bpage\s+\d+\b", "", caption) - - caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... - - caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) - - caption = re.sub(r"\b\s+\:\s+", r": ", caption) - caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) - caption = re.sub(r"\s+", " ", caption) - - caption.strip() - - caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) - caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) - caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) - caption = re.sub(r"^\.\S+$", "", caption) - - return caption.strip() - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents( - self, batch_size, num_channels_latents, num_frames, height, width, dtype, generator, latents=None - ): + def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, generator, latents=None): shape = ( batch_size, num_channels_latents, - (math.ceil((int(num_frames) - 1) / self.vae.vae_scale_factor[0]) + 1) - if int(num_frames) % 2 == 1 - else math.ceil(int(num_frames) / self.vae.vae_scale_factor[0]), - math.ceil(int(height) / self.vae.vae_scale_factor[1]), - math.ceil(int(width) / self.vae.vae_scale_factor[2]), + (int(num_frames) - 1) // self.vae.vae_scale_factor[0] + 1, + int(height) // self.vae.vae_scale_factor[1], + int(width) // self.vae.vae_scale_factor[2], ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) + if latents is None: - latents = ops.randn(shape, dtype=dtype) + latents = ops.randn(shape, dtype=dtype) #generator=generator else: latents = latents - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma + if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma return latents def prepare_parallel_latent(self, video_states): @@ -529,136 +456,96 @@ def prepare_parallel_latent(self, video_states): if padding_needed > 0: logger.debug("Doing video padding") # B, C, T, H, W -> B, C, T', H, W - video_states = ops.pad(video_states, (0, 0, 0, 0, 0, padding_needed), mode="constant", value=0) + video_states = mint.nn.functional.pad(video_states, (0, 0, 0, 0, 0, padding_needed), mode="constant", value=0) b, _, f, h, w = video_states.shape - temp_attention_mask = ops.ones((b, f), ms.int32) + temp_attention_mask = mint.ones((b, f), ms.int32) temp_attention_mask[:, -padding_needed:] = 0 assert video_states.shape[2] % sp_size == 0 video_states = ops.chunk(video_states, sp_size, 2)[index] return video_states, temp_attention_mask + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt def __call__( self, prompt: Union[str, List[str]] = None, - negative_prompt: str = "", - num_inference_steps: int = 20, - timesteps: List[int] = None, - guidance_scale: float = 4.5, - num_images_per_prompt: Optional[int] = 1, num_frames: Optional[int] = None, height: Optional[int] = None, width: Optional[int] = None, - eta: float = 0.0, - generator=None, + num_inference_steps: Optional[int] = 50, + timesteps: List[int] = None, + guidance_scale: Optional[float] = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_samples_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator = None, latents: Optional[ms.Tensor] = None, prompt_embeds: Optional[ms.Tensor] = None, - prompt_attention_mask: Optional[ms.Tensor] = None, + prompt_embeds_2: Optional[ms.Tensor] = None, negative_prompt_embeds: Optional[ms.Tensor] = None, + negative_prompt_embeds_2: Optional[ms.Tensor] = None, + prompt_attention_mask: Optional[ms.Tensor] = None, + prompt_attention_mask_2: Optional[ms.Tensor] = None, negative_prompt_attention_mask: Optional[ms.Tensor] = None, + negative_prompt_attention_mask_2: Optional[ms.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback: Optional[Callable[[int, int, ms.Tensor], None]] = None, - callback_steps: int = 1, - clean_caption: bool = True, - use_resolution_binning: bool = True, - max_sequence_length: int = 300, - **kwargs, - ) -> Union[ImagePipelineOutput, Tuple]: - """ - Function invoked when calling the pipeline for generation. + callback_on_step_end: Optional[Callable[[int, int, ms.Tensor], None]] = None, # Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]] + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + guidance_rescale: float = 0.0, + max_sequence_length: int = 512 + ): + # TODO + # if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + # callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - num_inference_steps (`int`, *optional*, defaults to 100): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` - timesteps are used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to 7.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of videos to generate per prompt. - height (`int`, *optional*, defaults to self.unet.config.sample_size): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size): - The width in pixels of the generated image. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`ms.Tensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`ms.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - prompt_attention_mask (`ms.Tensor`, *optional*): Pre-generated attention mask for text embeddings. - negative_prompt_embeds (`ms.Tensor`, *optional*): - Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not - provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - negative_prompt_attention_mask (`ms.Tensor`, *optional*): - Pre-generated attention mask for negative text embeddings. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: ms.Tensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - clean_caption (`bool`, *optional*, defaults to `True`): - Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to - be installed. If the dependencies are not installed, the embeddings will be created from the raw - prompt. - use_resolution_binning (`bool` defaults to `True`): - If set to `True`, the requested height and width are first mapped to the closest resolutions using - `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to - the requested resolution. Useful for generating non-square images. - max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`. - - Examples: - - Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is - returned where the first element is a list with the generated images - """ - # 1. Check inputs. Raise error if not correct - num_frames = num_frames or self.transformer.config.sample_size_t * self.vae.vae_scale_factor[0] + # 0. default height and width + num_frames = num_frames or (self.transformer.config.sample_size_t - 1) * self.vae.vae_scale_factor[0] + 1 height = height or self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1] width = width or self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2] + + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, - num_frames, + num_frames, height, width, negative_prompt, - callback_steps, prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False - # 2. Default height and width to transformer + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -666,44 +553,71 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt + ( prompt_embeds, - prompt_attention_mask, negative_prompt_embeds, + prompt_attention_mask, negative_prompt_attention_mask, ) = self.encode_prompt( - prompt, - do_classifier_free_guidance, + prompt=prompt, + dtype=self.transformer.dtype, + num_samples_per_prompt=num_samples_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, - num_images_per_prompt=num_images_per_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, - clean_caption=clean_caption, max_sequence_length=max_sequence_length, + text_encoder_index=0, ) - - if do_classifier_free_guidance: - prompt_embeds = ops.cat([negative_prompt_embeds, prompt_embeds], axis=0) - prompt_attention_mask = ops.cat([negative_prompt_attention_mask, prompt_attention_mask], axis=0) + if self.tokenizer_2 is not None: + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + dtype=self.transformer.dtype, + num_samples_per_prompt=num_samples_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + max_sequence_length=77, + text_encoder_index=1, + ) + else: + prompt_embeds_2 = None + negative_prompt_embeds_2 = None + prompt_attention_mask_2 = None + negative_prompt_attention_mask_2 = None # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps) + if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) - # 5. Prepare latents. - latent_channels = self.transformer.config.in_channels - world_size = hccl_info.world_size + # 5. Prepare latent variables + if get_sequence_parallel_state(): + world_size = hccl_info.world_size + num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( - batch_size * num_images_per_prompt, - latent_channels, - num_frames, + batch_size * num_samples_per_prompt, + num_channels_latents, + (num_frames + world_size - 1) // world_size if get_sequence_parallel_state() else num_frames, height, width, prompt_embeds.dtype, @@ -712,19 +626,33 @@ def __call__( ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + else: + extra_step_kwargs = {} - # 6.1 Prepare micro-conditions. - added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + # 7 create image_rotary_emb, style embedding & time ids + if self.do_classifier_free_guidance: + prompt_embeds = ops.cat([negative_prompt_embeds, prompt_embeds], axis=0) + prompt_attention_mask = ops.cat([negative_prompt_attention_mask, prompt_attention_mask], axis=0) + if self.tokenizer_2 is not None: + prompt_embeds_2 = ops.cat([negative_prompt_embeds_2, prompt_embeds_2], axis=0) + prompt_attention_mask_2 = ops.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2], axis=0) - # 7. Denoising loop - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - world_size = hccl_info.world_size + # ==================make sp===================================== if get_sequence_parallel_state(): + # # b (n x) h -> b n x h + # b, _, h = prompt_embeds.shape + # n = world_size + # x = prompt_embeds.shape[1] // world_size + # prompt_embeds = prompt_embeds.reshape(b, n, x, h).contiguous() + # rank = hccl_info.rank + # prompt_embeds = prompt_embeds[:, rank, :, :] + latents, temp_attention_mask = self.prepare_parallel_latent(latents) temp_attention_mask = ( ops.cat([temp_attention_mask] * 2) - if (do_classifier_free_guidance and temp_attention_mask is not None) + if (self.do_classifier_free_guidance and temp_attention_mask is not None) else temp_attention_mask ) # b (n x) h -> b n x h @@ -735,86 +663,126 @@ def __call__( prompt_embeds = prompt_embeds[:, index, :, :] else: temp_attention_mask = None + # ==================make sp===================================== + # 8. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - latent_model_input = ops.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - + if self.interrupt: + continue + + # print(f"t {t}:, latents {latents.shape}") + # expand the latents if we are doing classifier free guidance + latent_model_input = ops.cat([latents] * 2) if self.do_classifier_free_guidance else latents + if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input current_timestep = t if not isinstance(current_timestep, ms.Tensor): - if isinstance(current_timestep, float): - dtype = ms.float32 - else: - dtype = ms.int32 - current_timestep = ms.Tensor([current_timestep], dtype=dtype) + current_timestep = ms.Tensor([current_timestep], dtype=latent_model_input.dtype) elif len(current_timestep.shape) == 0: current_timestep = current_timestep[None] - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML current_timestep = current_timestep.repeat_interleave(latent_model_input.shape[0], 0) + + # ==================prepare my shape===================================== + # predict the noise residual if prompt_embeds.ndim == 3: prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d if prompt_attention_mask.ndim == 2: prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l - # b c t h w -> b t h w - attention_mask = ops.ones_like(latent_model_input)[:, 0] # b t h w + if prompt_embeds_2 is not None and prompt_embeds_2.ndim == 2: + prompt_embeds = prompt_embeds.unsqueeze(1) # b d -> b 1 d + + attention_mask = ops.ones_like(latent_model_input)[:, 0] if temp_attention_mask is not None: # temp_attention_mask shape (bs, t), 1 means to keep, 0 means to discard # TODO: mask temporal padded tokens attention_mask = ( attention_mask.to(ms.int32) * temp_attention_mask[:, :, None, None].to(ms.int32) ).to(ms.bool_) + # ==================prepare my shape===================================== + + # ==================make sp===================================== if get_sequence_parallel_state(): - attention_mask = attention_mask.tile((1, world_size, 1, 1)) # b t*sp_size h w - # predict noise model_output + attention_mask = attention_mask.repeat(world_size, axis = 1) + # ==================make sp===================================== + noise_pred = ops.stop_gradient( - self.transformer( - latent_model_input, # (b c t h w) - attention_mask=attention_mask, - encoder_hidden_states=prompt_embeds, # (b n c) - encoder_attention_mask=prompt_attention_mask, # (b n) - timestep=current_timestep, # (b) - added_cond_kwargs=added_cond_kwargs, + self.transformer( + latent_model_input, # (b c t h w) + attention_mask=attention_mask, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + pooled_projections=prompt_embeds_2, + return_dict=False, ) - ) - + ) # b,c,t,h,w + assert not ops.any(ops.isnan(noise_pred.float())) # perform guidance - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # learned sigma - if self.transformer.config.out_channels // 2 == latent_channels: - noise_pred = noise_pred.chunk(2, axis=1)[0] # b c t h w - else: - noise_pred = noise_pred + if self.do_classifier_free_guidance and guidance_rescale > 0.0 and not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) - # compute previous image: x_t -> x_t-1 + # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) + # ==================make sp===================================== if get_sequence_parallel_state(): - sp_size = hccl_info.world_size - latents = self.all_gather(latents) - latents_list = mint.chunk(latents, sp_size, 0) - latents = ops.concat(latents_list, axis=2)[:, :, :num_frames] + # latents_shape = list(latents.shape) # b c t//sp h w + # full_shape = [latents_shape[0] * world_size] + latents_shape[1:] # # b*sp c t//sp h w + # all_latents = ops.zeros(full_shape, dtype=latents.dtype) + all_latents = self.all_gather(latents) + latents_list = mint.chunk(all_latents, world_size, axis = 0) + latents = ops.cat(latents_list, axis=2) + # ==================make sp===================================== if not output_type == "latents": - image = self.decode_latents(latents) - image = image[:, :num_frames, :height, :width] + videos = self.decode_latents(latents) + videos = videos[:, :num_frames, :height, :width] else: - image = latents - if not return_dict: - return (image,) + videos = latents - return ImagePipelineOutput(images=image) + # Offload all models + # self.maybe_free_model_hooks() + if not return_dict: + return (videos, ) + + return OpenSoraPipelineOutput(videos=videos) + # TODO: update OpenSoraPipelineOutput according to VAE usage + + + # def decode_latents(self, latents): + # print(f'before vae decode {latents.shape}', ops.max(latents).item(), ops.min(latents).item(), ops.mean(latents).item(), ops.std(latents).item()) + # video = self.vae.decode(latents.to(self.vae.vae.dtype)) + # print(f'after vae decode {latents.shape}', ops.max(video).item(), ops.min(video).item(), ops.mean(video).item(), ops.std(video).item()) + # video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=ms.uint8).permute(0, 1, 3, 4, 2).contiguous() # b t h w c + # return video + + #TODO: TBD depending on VAE usage def decode_latents_per_sample(self, latents): # video = self.vae.decode(latents) video = self.vae.decode(latents).to(ms.float32) # (b t c h w) diff --git a/examples/opensora_pku/opensora/sample/sample_t2v.py b/examples/opensora_pku/opensora/sample/sample.py similarity index 51% rename from examples/opensora_pku/opensora/sample/sample_t2v.py rename to examples/opensora_pku/opensora/sample/sample.py index 3f0204438a..47a5ff8e6f 100644 --- a/examples/opensora_pku/opensora/sample/sample_t2v.py +++ b/examples/opensora_pku/opensora/sample/sample.py @@ -14,34 +14,36 @@ import mindspore as ms from mindspore import nn + # TODO: remove in future when mindone is ready for install mindone_lib_path = os.path.abspath("../../") sys.path.insert(0, mindone_lib_path) sys.path.append(os.path.abspath("./")) +from opensora.npu_config import npu_config from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info from opensora.dataset.text_dataset import create_dataloader -from opensora.models.causalvideovae import CausalVAEModelWrapper, ae_stride_config -from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate -from opensora.models.diffusion.opensora.modeling_opensora import LayerNorm, OpenSoraT2V -from opensora.models.diffusion.opensora.modules import Attention -from opensora.sample.pipeline_opensora import OpenSoraPipeline from opensora.utils.message_utils import print_banner from opensora.utils.ms_utils import init_env from opensora.utils.utils import _check_cfgs_in_parser, get_precision -from transformers import AutoTokenizer +from opensora.models.causalvideovae import ae_stride_config, ae_wrapper +# from opensora.sample.caption_refiner import OpenSoraCaptionRefiner #TODO +from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate +from examples.opensora_pku.opensora.models.diffusion.opensora.modeling_opensora import LayerNorm, OpenSoraT2V_v1_3 +from examples.opensora_pku.opensora.models.diffusion.opensora.modules import Attention +from opensora.sample.pipeline_opensora import OpenSoraPipeline from mindone.diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings -from mindone.diffusers.schedulers import ( - DDIMScheduler, - DDPMScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - PNDMScheduler, -) -from mindone.transformers import MT5EncoderModel - -# from mindone.transformers.activations import NewGELUActivation -# from mindone.transformers.models.mt5.modeling_mt5 import MT5LayerNorm +from mindone.diffusers import ( + DDIMScheduler, DDPMScheduler, PNDMScheduler, + EulerDiscreteScheduler, DPMSolverMultistepScheduler, + HeunDiscreteScheduler, EulerAncestralDiscreteScheduler, + DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler, + DPMSolverSinglestepScheduler, #CogVideoXDDIMScheduler, + FlowMatchEulerDiscreteScheduler + ) +from mindone.transformers import T5EncoderModel, MT5EncoderModel, CLIPTextModelWithProjection +from transformers import AutoTokenizer, MT5Tokenizer + from mindone.utils.amp import auto_mixed_precision from mindone.utils.config import str2bool from mindone.utils.logger import set_logger @@ -50,9 +52,76 @@ logger = logging.getLogger(__name__) +# ms.set_context(pynative_synchronize=True) + +# Copied from opensora.utils sample_utils.py +def get_scheduler(args): + kwargs = dict( + prediction_type=args.prediction_type, + rescale_betas_zero_snr=args.rescale_betas_zero_snr, + timestep_spacing="trailing" if args.rescale_betas_zero_snr else 'leading', + ) + if args.v1_5_scheduler: + kwargs['beta_start'] = 0.00085 + kwargs['beta_end'] = 0.0120 + kwargs['beta_schedule'] = "scaled_linear" + if args.sample_method == 'DDIM': + scheduler_cls = DDIMScheduler + kwargs['clip_sample'] = False + elif args.sample_method == 'EulerDiscrete': + scheduler_cls = EulerDiscreteScheduler + elif args.sample_method == 'DDPM': + scheduler_cls = DDPMScheduler + kwargs['clip_sample'] = False + elif args.sample_method == 'DPMSolverMultistep': + scheduler_cls = DPMSolverMultistepScheduler + elif args.sample_method == 'DPMSolverSinglestep': + scheduler_cls = DPMSolverSinglestepScheduler + elif args.sample_method == 'PNDM': + scheduler_cls = PNDMScheduler + kwargs.pop('rescale_betas_zero_snr', None) + elif args.sample_method == 'HeunDiscrete': ######## + scheduler_cls = HeunDiscreteScheduler + elif args.sample_method == 'EulerAncestralDiscrete': + scheduler_cls = EulerAncestralDiscreteScheduler + elif args.sample_method == 'DEISMultistep': + scheduler_cls = DEISMultistepScheduler + kwargs.pop('rescale_betas_zero_snr', None) + elif args.sample_method == 'KDPM2AncestralDiscrete': ######### + scheduler_cls = KDPM2AncestralDiscreteScheduler + # elif args.sample_method == 'CogVideoX': + # scheduler_cls = CogVideoXDDIMScheduler + elif args.sample_method == 'FlowMatchEulerDiscrete': + scheduler_cls = FlowMatchEulerDiscreteScheduler + kwargs = {} + else: + raise NameError(f'Unsupport sample_method {args.sample_method}') + scheduler = scheduler_cls(**kwargs) + return scheduler def parse_args(): parser = argparse.ArgumentParser() + + parser.add_argument("--version", type=str, default='v1_3', choices=['v1_3', 'v1_5']) + parser.add_argument("--caption_refiner", type=str, default=None) + parser.add_argument("--enhance_video", type=str, default=None) + parser.add_argument("--text_encoder_name_1", type=str, default='DeepFloyd/t5-v1_1-xxl', help="google/mt5-xxl, DeepFloyd/t5-v1_1-xxl") + parser.add_argument("--text_encoder_name_2", type=str, default=None, help=" openai/clip-vit-large-patch14, (laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)") + parser.add_argument("--num_samples_per_prompt", type=int, default=1) + parser.add_argument('--refine_caption', action='store_true') + # parser.add_argument('--compile', action='store_true') + parser.add_argument("--prediction_type", type=str, default='epsilon', help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.") + parser.add_argument('--rescale_betas_zero_snr', action='store_true') + # parser.add_argument('--local_rank', type=int, default=-1) + # parser.add_argument('--world_size', type=int, default=1) + # parser.add_argument('--sp', action='store_true') + + parser.add_argument('--v1_5_scheduler', action='store_true') + parser.add_argument('--conditional_pixel_values_path', type=str, default=None) + parser.add_argument('--mask_type', type=str, default=None) + parser.add_argument('--crop_for_hw', action='store_true') + parser.add_argument('--max_hxw', type=int, default=236544) # 480*480 + parser.add_argument( "--config", "-c", @@ -60,7 +129,7 @@ def parse_args(): type=str, help="path to load a config yaml file that describes the setting which will override the default arguments", ) - parser.add_argument("--model_path", type=str, default="LanguageBind/Open-Sora-Plan-v1.2.0") + parser.add_argument("--model_path", type=str, default="LanguageBind/Open-Sora-Plan-v1.3.0") parser.add_argument( "--ms_checkpoint", type=str, @@ -79,7 +148,7 @@ def parse_args(): parser.add_argument("--save_img_path", type=str, default="./sample_videos/t2v") parser.add_argument("--guidance_scale", type=float, default=7.5, help="the scale for classifier-free guidance") - parser.add_argument("--max_sequence_length", type=int, default=300, help="the maximum text tokens length") + parser.add_argument("--max_sequence_length", type=int, default=512, help="the maximum text tokens length") parser.add_argument("--sample_method", type=str, default="PNDM") parser.add_argument("--num_sampling_steps", type=int, default=50, help="Diffusion Sampling Steps") @@ -109,7 +178,7 @@ def parse_args(): parser.add_argument( "--jit_syntax_level", default="strict", choices=["strict", "lax"], help="Set jit syntax level: strict or lax" ) - parser.add_argument("--seed", type=int, default=4, help="Inference seed") + parser.add_argument("--seed", type=int, default=42, help="Inference seed") parser.add_argument( "--precision", @@ -166,7 +235,7 @@ def parse_args(): parser.add_argument( "--video_extension", default="mp4", choices=["gif", "mp4"], help="The file extension to save videos" ) - parser.add_argument("--model_type", type=str, default="dit", choices=["dit", "udit", "latte"]) + parser.add_argument("--model_type", type=str, default="dit", choices=["dit", "udit", "latte", 't2v', 'inpaint', 'i2v']) parser.add_argument("--cache_dir", type=str, default="./") parser.add_argument("--profile", default=False, type=str2bool, help="Profile or not") default_args = parser.parse_args() @@ -179,134 +248,93 @@ def parse_args(): _check_cfgs_in_parser(cfg, parser) parser.set_defaults(**cfg) args = parser.parse_args() - return args + assert not (args.use_parallel and args.num_frames == 1) -if __name__ == "__main__": - # 1. init env - args = parse_args() - save_dir = args.save_img_path - os.makedirs(save_dir, exist_ok=True) - set_logger(name="", output_dir=save_dir) - # 1. init - rank_id, device_num = init_env( - args.mode, - seed=args.seed, - distributed=args.use_parallel, - device_target=args.device, - max_device_memory=args.max_device_memory, - parallel_mode=args.parallel_mode, - precision_mode=args.precision_mode, - global_bf16=args.global_bf16, - sp_size=args.sp_size, - jit_level=args.jit_level, - jit_syntax_level=args.jit_syntax_level, - ) + return args - # 2. vae model initiate and weight loading + +def prepare_pipeline(args): + # VAE model initiate and weight loading print_banner("vae init") - vae = CausalVAEModelWrapper(args.ae_path, cache_dir=args.cache_dir, use_safetensors=True) + vae_dtype = get_precision(args.vae_precision) + if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint): + logger.info(f"Run inference with MindSpore checkpoint {args.ms_checkpoint}") + state_dict = ms.load_checkpoint(args.ms_checkpoint) + # rm 'network.' prefix + state_dict = dict( + [k.replace("network.", "") if k.startswith("network.") else k, v] for k, v in state_dict.items() + ) + else: + state_dict = None + kwarg = { + "state_dict": state_dict, + "use_safetensors": True, + "dtype": vae_dtype, + } + vae = ae_wrapper[args.ae](args.ae_path, **kwarg) + vae.vae_scale_factor = ae_stride_config[args.ae] if args.enable_tiling: vae.vae.enable_tiling() vae.vae.tile_overlap_factor = args.tile_overlap_factor - vae.vae.tile_sample_min_size = 512 - vae.vae.tile_latent_min_size = 64 - vae.vae.tile_sample_min_size_t = 29 - vae.vae.tile_latent_min_size_t = 8 - if args.save_memory: - vae.vae.tile_sample_min_size = 256 - vae.vae.tile_latent_min_size = 32 - vae.vae.tile_sample_min_size_t = 29 - vae.vae.tile_latent_min_size_t = 8 - vae.vae_scale_factor = ae_stride_config[args.ae] - # use amp level O2 for causal 3D VAE with bfloat16 or float16 - vae_dtype = get_precision(args.vae_precision) + + ## use amp level O2 for causal 3D VAE with bfloat16 or float16 if vae_dtype == ms.float16: custom_fp32_cells = [nn.GroupNorm] if args.vae_keep_gn_fp32 else [] else: custom_fp32_cells = [nn.AvgPool2d, TrilinearInterpolate] - vae = auto_mixed_precision(vae, amp_level="O2", dtype=vae_dtype, custom_fp32_cells=custom_fp32_cells) logger.info(f"Use amp level O2 for causal 3D VAE with dtype={vae_dtype}, custom_fp32_cells: {custom_fp32_cells}") + vae = auto_mixed_precision(vae, amp_level="O2", dtype=vae_dtype, custom_fp32_cells=custom_fp32_cells) + vae.set_train(False) for param in vae.get_parameters(): # freeze vae - param.requires_grad = False - - # 3. handle input text prompts - print_banner("text prompts loading") - ext = ( - f"{args.video_extension}" if not (args.save_latents or args.decode_latents) else "npy" - ) # save video as gif or save denoised latents as npy files. - ext = "jpg" if args.num_frames == 1 else ext - if not isinstance(args.text_prompt, list): - args.text_prompt = [args.text_prompt] - # if input is a text file, where each line is a caption, load it into a list - if len(args.text_prompt) == 1 and args.text_prompt[0].endswith("txt"): - captions = open(args.text_prompt[0], "r").readlines() - args.text_prompt = [i.strip() for i in captions] - if len(args.text_prompt) == 1 and args.text_prompt[0].endswith("csv"): - captions = pd.read_csv(args.text_prompt[0]) - args.text_prompt = [i.strip() for i in captions["cap"]] - n = len(args.text_prompt) - assert n > 0, "No captions provided" - logger.info(f"Number of prompts: {n}") - logger.info(f"Number of generated samples for each prompt {args.num_videos_per_prompt}") - - # create dataloader for the captions - csv_file = {"path": [], "cap": []} - for i in range(n): - for i_video in range(args.num_videos_per_prompt): - csv_file["path"].append(f"{i_video}-{args.text_prompt[i].strip()[:100]}.{ext}") - csv_file["cap"].append(args.text_prompt[i]) - temp_dataset_csv = os.path.join(save_dir, "dataset.csv") - pd.DataFrame.from_dict(csv_file).to_csv(temp_dataset_csv, index=False, columns=csv_file.keys()) - - ds_config = dict( - data_file_path=temp_dataset_csv, - tokenizer=None, # tokenizer, - file_column="path", - caption_column="cap", - ) - dataset = create_dataloader( - ds_config, - args.batch_size, - ds_name="text", - num_parallel_workers=12, - max_rowsize=32, - shuffle=False, # be in order - device_num=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), - rank_id=rank_id if not get_sequence_parallel_state() else hccl_info.group_id, - drop_remainder=False, - ) - dataset_size = dataset.get_dataset_size() - logger.info(f"Num batches: {dataset_size}") - ds_iter = dataset.create_dict_iterator(1, output_numpy=True) - + param.requires_grad = False + if args.decode_latents: - for step, data in tqdm(enumerate(ds_iter), total=dataset_size): - file_paths = data["file_path"] - loaded_latents = [] - for i_sample in range(args.batch_size): - save_fp = os.path.join(save_dir, file_paths[i_sample]) - assert os.path.exists( - save_fp - ), f"{save_fp} does not exist! Please check the npy files under {save_dir} or check if you run `--save_latents` ahead." - loaded_latents.append(np.load(save_fp)) - loaded_latents = ( - np.stack(loaded_latents) if loaded_latents[0].ndim == 4 else np.concatenate(loaded_latents, axis=0) + print("To decode latents directly, skipped loading text endoers and transformer") + return vae + + # Build text encoders + print_banner("text encoder init") + text_encoder_dtype = get_precision(args.text_encoder_precision) + if 'mt5' in args.text_encoder_name_1: + text_encoder_1, loading_info = MT5EncoderModel.from_pretrained( + args.text_encoder_name_1, + cache_dir=args.cache_dir, + output_loading_info=True, + mindspore_dtype=text_encoder_dtype, + use_safetensors=True + ) + loading_info.pop("unexpected_keys") # decoder weights are ignored + logger.info(f"Loaded MT5 Encoder: {loading_info}") + text_encoder_1 = text_encoder_1.set_train(False) + else: + text_encoder_1 = T5EncoderModel.from_pretrained( + args.text_encoder_name_1, cache_dir=args.cache_dir, + mindspore_dtype=text_encoder_dtype + ).set_train(False) + tokenizer_1 = AutoTokenizer.from_pretrained( + args.text_encoder_name_1, cache_dir=args.cache_dir + ) + + if args.text_encoder_name_2 is not None: + text_encoder_2, loading_info = CLIPTextModelWithProjection.from_pretrained( + args.text_encoder_name_2, + cache_dir=args.cache_dir, + mindspore_dtype=text_encoder_dtype, + output_loading_info=True, + use_safetensors=True, ) - decode_data = ( - vae.decode(ms.Tensor(loaded_latents)).permute(0, 1, 3, 4, 2).to(ms.float32) - ) # (b t c h w) -> (b t h w c) - decode_data = ms.ops.clip_by_value( - (decode_data + 1.0) / 2.0, clip_value_min=0.0, clip_value_max=1.0 - ).asnumpy() - for i_sample in range(args.batch_size): - save_fp = os.path.join(save_dir, file_paths[i_sample]).replace(".npy", f".{args.video_extension}") - save_video_data = decode_data[i_sample : i_sample + 1] - save_videos(save_video_data, save_fp, loop=0, fps=args.fps) # (b t h w c) - sys.exit() + loading_info.pop("unexpected_keys") # only load text model, ignore vision model + logger.info(f"Loaded CLIP Encoder: {loading_info}") # Note: missed keys when loading open-clip models + text_encoder_2 = text_encoder_2.set_train(False) + tokenizer_2 = AutoTokenizer.from_pretrained( + args.text_encoder_name_2, cache_dir=args.cache_dir + ) + else: + text_encoder_2, tokenizer_2 = None, None - # 4. latte model initiate and weight loading + # Build transformer print_banner("transformer model init") FA_dtype = get_precision(args.precision) if get_precision(args.precision) != ms.float32 else ms.bfloat16 assert args.model_type == "dit", "Currently only suppport model_type as 'dit'@" @@ -319,21 +347,54 @@ def parse_args(): ) else: state_dict = None - model_version = args.model_path.split("/")[-1] - if int(model_version.split("x")[0]) != args.num_frames: - logger.warning( - f"Detect that the loaded model version is {model_version}, but found a mismatched number of frames {model_version.split('x')[0]}" - ) - if int(model_version.split("x")[1][:-1]) != args.height: - logger.warning( - f"Detect that the loaded model version is {model_version}, but found a mismatched resolution {args.height}x{args.width}" - ) - transformer_model, logging_info = OpenSoraT2V.from_pretrained( - args.model_path, state_dict=state_dict, cache_dir=args.cache_dir, FA_dtype=FA_dtype, output_loading_info=True - ) - logger.info(logging_info) - # mixed precision + if (args.version != 'v1_3') and (model_version.split("x")[0][:3] != "any"): + if int(model_version.split("x")[0]) != args.num_frames: + logger.warning( + f"Detect that the loaded model version is {model_version}, but found a mismatched number of frames {model_version.split('x')[0]}" + ) + if int(model_version.split("x")[1][:-1]) != args.height: + logger.warning( + f"Detect that the loaded model version is {model_version}, but found a mismatched resolution {args.height}x{args.width}" + ) + elif (args.version == 'v1_3') and (model_version.split("x")[0] == "any93x640x640"): # TODO: currently only release one model + if (args.height % 32 != 0) or (args.width % 32 != 0): + logger.warning( + f"Detect that the loaded model version is {model_version}, but found a mismatched resolution {args.height}x{args.width}. The resolution of the inference should be a multiple of 32." + ) + if (args.num_frames - 1) % 4 != 0: + logger.warning( + f"Detect that the loaded model version is {model_version}, but found a mismatched number of frames {args.num_frames}. Frames needs to be 4n+1, e.g. 93, 77, 61, 45, 29, 1 (image)" + ) + + if args.version == 'v1_3': + # TODO + # if args.model_type == 'inpaint' or args.model_type == 'i2v': + # transformer_model = OpenSoraInpaint_v1_3.from_pretrained( + # args.model_path, cache_dir=args.cache_dir, + # device_map=None, mindspore_dtype=weight_dtype + # ).set_train(False) + # else: + transformer_model, logging_info = OpenSoraT2V_v1_3.from_pretrained( + args.model_path, + state_dict=state_dict, + cache_dir=args.cache_dir, + FA_dtype = FA_dtype, + output_loading_info=True + ) + logger.info(logging_info) + elif args.version == 'v1_5': + if args.model_type == 'inpaint' or args.model_type == 'i2v': + raise NotImplementedError('Inpainting model is not available in v1_5') + else: + from opensora.models.diffusion.opensora_v1_5.modeling_opensora import OpenSoraT2V_v1_5 + transformer_model = OpenSoraT2V_v1_5.from_pretrained( + args.model_path, cache_dir=args.cache_dir, + # device_map=None, + mindspore_dtype=weight_dtype + ) + + # Mixed precision dtype = get_precision(args.precision) if args.precision in ["fp16", "bf16"]: if not args.global_bf16: @@ -363,47 +424,28 @@ def parse_args(): amp_level = "O0" else: raise ValueError(f"Unsupported precision {args.precision}") - transformer_model = transformer_model.set_train(False) for param in transformer_model.get_parameters(): # freeze transformer_model param.requires_grad = False - print_banner("text encoder init") - text_encoder_dtype = get_precision(args.text_encoder_precision) - text_encoder, loading_info = MT5EncoderModel.from_pretrained( - args.text_encoder_name, - cache_dir=args.cache_dir, - output_loading_info=True, - mindspore_dtype=text_encoder_dtype, - use_safetensors=True, - ) - loading_info.pop("unexpected_keys") # decoder weights are ignored - logger.info(loading_info) - tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir) - - # 3. build inference pipeline - if args.sample_method == "DDIM": - scheduler = DDIMScheduler() - elif args.sample_method == "DDPM": - scheduler = DDPMScheduler() - elif args.sample_method == "PNDM": - scheduler = PNDMScheduler() - elif args.sample_method == "EulerDiscrete": - scheduler = EulerDiscreteScheduler() - elif args.sample_method == "EulerAncestralDiscrete": - scheduler = EulerAncestralDiscreteScheduler() - else: - raise ValueError(f"Not supported sampling method {args.sample_method}") + # Build scheduler + scheduler = get_scheduler(args) + + # Build inference pipeline + # pipeline_class = OpenSoraInpaintPipeline if args.model_type == 'inpaint' or args.model_type == 'i2v' else OpenSoraPipeline + pipeline_class = OpenSoraPipeline - pipeline = OpenSoraPipeline( + pipeline = pipeline_class( vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, + text_encoder=text_encoder_1, + tokenizer=tokenizer_1, scheduler=scheduler, - transformer=transformer_model, + transformer=transformer_model, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, ) - # 4. print key info + # Print key info num_params_vae, num_params_vae_trainable = count_params(vae) num_params_latte, num_params_latte_trainable = count_params(transformer_model) num_params = num_params_vae + num_params_latte @@ -413,7 +455,7 @@ def parse_args(): [ f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}", f"Jit level: {args.jit_level}", - f"Num of samples: {n}", + f"Num of samples: {len(args.text_prompt)}", f"Num params: {num_params:,} (latte: {num_params_latte:,}, vae: {num_params_vae:,})", f"Num trainable params: {num_params_trainable:,}", f"Transformer dtype: {dtype}", @@ -428,41 +470,176 @@ def parse_args(): ) key_info += "\n" + "=" * 50 logger.info(key_info) - start_time = time.time() - if args.profile: - profiler = ms.Profiler(output_path="./mem_info", profile_memory=True) - ms.set_context(memory_optimize_level="O0") - ms.set_context(pynative_synchronize=True) - else: - profiler = None - # infer - for step, data in tqdm(enumerate(ds_iter), total=dataset_size): + + return pipeline + +def run_model_and_save_samples(args, pipeline, rank_id, device_num, caption_refiner_model=None, enhance_video_model=None): + + # Handle input text prompts + print_banner("text prompts loading") + ext = ( + f"{args.video_extension}" if not (args.save_latents or args.decode_latents) else "npy" + ) # save video as gif or save denoised latents as npy files. + ext = "jpg" if args.num_frames == 1 else ext + if not isinstance(args.text_prompt, list): + args.text_prompt = [args.text_prompt] + # if input is a text file, where each line is a caption, load it into a list + if len(args.text_prompt) == 1 and args.text_prompt[0].endswith("txt"): + captions = open(args.text_prompt[0], "r").readlines() + args.text_prompt = [i.strip() for i in captions] + if len(args.text_prompt) == 1 and args.text_prompt[0].endswith("csv"): + captions = pd.read_csv(args.text_prompt[0]) + args.text_prompt = [i.strip() for i in captions["cap"]] + n = len(args.text_prompt) + assert n > 0, "No captions provided" + logger.info(f"Number of prompts: {n}") + logger.info(f"Number of generated samples for each prompt {args.num_videos_per_prompt}") + + # Create dataloader for the captions + csv_file = {"path": [], "cap": []} + for i in range(n): + for i_video in range(args.num_videos_per_prompt): + csv_file["path"].append(f"{i_video}-{args.text_prompt[i].strip()[:100]}.{ext}") + csv_file["cap"].append(args.text_prompt[i]) + temp_dataset_csv = os.path.join(save_dir, "dataset.csv") + pd.DataFrame.from_dict(csv_file).to_csv(temp_dataset_csv, index=False, columns=csv_file.keys()) + + ds_config = dict( + data_file_path=temp_dataset_csv, + tokenizer=None, # tokenizer, + file_column="path", + caption_column="cap", + ) + dataset = create_dataloader( + ds_config, + args.batch_size, + ds_name="text", + num_parallel_workers=12, + max_rowsize=32, + shuffle=False, # be in order + device_num=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), + rank_id=rank_id if not get_sequence_parallel_state() else hccl_info.group_id, + drop_remainder=False, + ) + dataset_size = dataset.get_dataset_size() + logger.info(f"Num batches: {dataset_size}") + ds_iter = dataset.create_dict_iterator(1, output_numpy=True) + + # Decode latents directly + if args.decode_latents: + print("Directly decoding latents...") + assert isinstance(pipeline, ae_wrapper[args.ae]) + vae = pipeline + for step, data in tqdm(enumerate(ds_iter), total=dataset_size): + file_paths = data["file_path"] + loaded_latents = [] + for i_sample in range(args.batch_size): + save_fp = os.path.join(save_dir, file_paths[i_sample]) + assert os.path.exists( + save_fp + ), f"{save_fp} does not exist! Please check the npy files under {save_dir} or check if you run `--save_latents` ahead." + loaded_latents.append(np.load(save_fp)) + loaded_latents = ( + np.stack(loaded_latents) if loaded_latents[0].ndim == 4 else np.concatenate(loaded_latents, axis=0) + ) + decode_data = ( + vae.decode(ms.Tensor(loaded_latents)).permute(0, 1, 3, 4, 2).to(ms.float32) + ) # (b t c h w) -> (b t h w c) + decode_data = ms.ops.clip_by_value( + (decode_data + 1.0) / 2.0, clip_value_min=0.0, clip_value_max=1.0 + ).asnumpy() + for i_sample in range(args.batch_size): + save_fp = os.path.join(save_dir, file_paths[i_sample]).replace(".npy", f".{args.video_extension}") + save_video_data = decode_data[i_sample : i_sample + 1] + save_videos(save_video_data, save_fp, loop=0, fps=args.fps) # (b t h w c) + + # Delete files that are no longer needed + if os.path.exists(temp_dataset_csv): + os.remove(temp_dataset_csv) + + if args.decode_latents: + npy_files = glob.glob(os.path.join(save_dir, "*.npy")) + for fp in npy_files: + os.remove(fp) + + # TODO + # if args.model_type == 'inpaint' or args.model_type == 'i2v': + # if not isinstance(args.conditional_pixel_values_path, list): + # args.conditional_pixel_values_path = [args.conditional_pixel_values_path] + # if len(args.conditional_pixel_values_path) == 1 and args.conditional_pixel_values_path[0].endswith('txt'): + # temp = open(args.conditional_pixel_values_path[0], 'r').readlines() + # conditional_pixel_values_path = [i.strip().split(',') for i in temp] + # mask_type = args.mask_type if args.mask_type is not None else None + + # positive_prompt = """ + # high quality, high aesthetic, {} + # """ + # negative_prompt = """ + # nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, + # low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry. + # """ + positive_prompt = ( + "(masterpiece), (best quality), (ultra-detailed), {}. emotional, " + + "harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo, sharp focus, high budget, cinemascope, moody, epic, gorgeous" + ) + negative_prompt = ( + "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, " + + "extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" + ) + + def generate(step, data, ext, conditional_pixel_values_path=None, mask_type=None): + + # TODO + # if args.caption_refiner is not None: + # if args.model_type != 'inpaint' and args.model_type != 'i2v': + # refine_prompt = caption_refiner_model.get_refiner_output(prompt) + # print(f'\nOrigin prompt: {prompt}\n->\nRefine prompt: {refine_prompt}') + # prompt = refine_prompt + # else: + # # Due to the current use of LLM as the caption refiner, additional content that is not present in the control image will be added. Therefore, caption refiner is not used in this mode. + # print('Caption refiner is not available for inpainting model, use the original prompt...') + # time.sleep(3) + # input_prompt = positive_prompt.format(prompt) + # if args.model_type == 'inpaint' or args.model_type == 'i2v': + # print(f'\nConditional pixel values path: {conditional_pixel_values_path}') + # videos = pipeline( + # conditional_pixel_values_path=conditional_pixel_values_path, + # mask_type=mask_type, + # crop_for_hw=args.crop_for_hw, + # max_hxw=args.max_hxw, + # prompt=input_prompt, + # negative_prompt=negative_prompt, + # num_frames=args.num_frames, + # height=args.height, + # width=args.width, + # num_inference_steps=args.num_sampling_steps, + # guidance_scale=args.guidance_scale, + # num_samples_per_prompt=args.num_samples_per_prompt, + # max_sequence_length=args.max_sequence_length, + # ).videos + # else: prompt = [x for x in data["caption"]] file_paths = data["file_path"] - positive_prompt = ( - "(masterpiece), (best quality), (ultra-detailed), {}. emotional, " - + "harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo, sharp focus, high budget, cinemascope, moody, epic, gorgeous" - ) - negative_prompt = ( - "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, " - + "extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" - ) - + input_prompt = positive_prompt.format(prompt) videos = ( pipeline( - positive_prompt.format(prompt), - negative_prompt=negative_prompt, + input_prompt, + negative_prompt=negative_prompt, num_frames=args.num_frames, height=args.height, width=args.width, num_inference_steps=args.num_sampling_steps, guidance_scale=args.guidance_scale, + num_samples_per_prompt=args.num_samples_per_prompt, output_type="latents" if args.save_latents else "pil", max_sequence_length=args.max_sequence_length, ) - .images.to(ms.float32) + .videos.to(ms.float32) .asnumpy() ) + # if enhance_video_model is not None: + # # b t h w c + # videos = enhance_video_model.enhance_a_video(videos, input_prompt, 2.0, args.fps, 250) if step == 0 and profiler is not None: profiler.stop() @@ -485,17 +662,73 @@ def parse_args(): save_video_data = videos[i_sample : i_sample + 1] # (b t h w c) save_videos(save_video_data, file_path, loop=0, fps=args.fps) - end_time = time.time() - time_cost = end_time - start_time - logger.info(f"Inference time cost: {time_cost:0.3f}s") - logger.info(f"Inference speed: {n / time_cost:0.3f} samples/s") - logger.info(f"{'latents' if args.save_latents else 'videos' } saved to {save_dir}") + if args.profile: + profiler = ms.Profiler(output_path="./mem_info", profile_memory=True) + ms.set_context(memory_optimize_level="O0") + ms.set_context(pynative_synchronize=True) + else: + profiler = None - # delete files that are no longer needed + # Infer + # if args.model_type == 'inpaint' or args.model_type == 'i2v': + # for index, (prompt, cond_path) in enumerate(zip(args.text_prompt, conditional_pixel_values_path)): + # if not args.sp and args.local_rank != -1 and index % args.world_size != args.local_rank: + # continue + # generate(prompt, conditional_pixel_values_path=cond_path, mask_type=mask_type) + # print('completed, please check the saved images and videos') + # else: + for step, data in tqdm(enumerate(ds_iter), total=dataset_size): + generate(step, data, ext) + + + # Delete files that are no longer needed if os.path.exists(temp_dataset_csv): os.remove(temp_dataset_csv) - if args.decode_latents: - npy_files = glob.glob(os.path.join(save_dir, "*.npy")) - for fp in npy_files: - os.remove(fp) + + +if __name__ == "__main__": + args = parse_args() + save_dir = args.save_img_path + os.makedirs(save_dir, exist_ok=True) + set_logger(name="", output_dir=save_dir) + + # 1. init environment + rank_id, device_num = npu_config.set_npu_env(args) + # rank_id, device_num = init_env( + # args.mode, + # seed=args.seed, + # distributed=args.use_parallel, + # device_target=args.device, + # max_device_memory=args.max_device_memory, + # parallel_mode=args.parallel_mode, + # precision_mode=args.precision_mode, + # global_bf16=args.global_bf16, + # sp_size=args.sp_size, + # jit_level=args.jit_level, + # jit_syntax_level=args.jit_syntax_level, + # ) + + # 2. build models and pipeline + if args.num_frames != 1 and args.enhance_video is not None: + from opensora.sample.VEnhancer.enhance_a_video import VEnhancer + enhance_video_model = VEnhancer(model_path=args.enhance_video, version='v2', device=args.device) + else: + enhance_video_model = None + pipeline = prepare_pipeline(args) + + # TODO + # if args.caption_refiner is not None: + # caption_refiner_model = OpenSoraCaptionRefiner(args, dtype, device) + # else: + caption_refiner_model = None + + # 3. inference + start_time = time.time() + run_model_and_save_samples(args, pipeline, rank_id, device_num, caption_refiner_model, enhance_video_model) + end_time = time.time() + time_cost = end_time - start_time + logger.info(f"Inference time cost: {time_cost:0.3f}s") + logger.info(f"Inference speed: {len(args.text_prompt) / time_cost:0.3f} samples/s") + logger.info(f"{'latents' if args.save_latents else 'videos' } saved to {save_dir}") + \ No newline at end of file diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x1280.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x1280.sh new file mode 100644 index 0000000000..665483b0ca --- /dev/null +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x1280.sh @@ -0,0 +1,26 @@ +# The DiT model is trained arbitrarily on stride=32. +# So keep the resolution of the inference a multiple of 32. Frames needs to be 4n+1, e.g. 93, 77, 61, 45, 29, 1 (image). + +export DEVICE_ID=0 +python opensora/sample/sample.py \ + --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ + --version v1_3 \ + --num_frames 29 \ + --height 704 \ + --width 1280 \ + --text_encoder_name_1 google/mt5-xxl \ + --text_encoder_name_2 laion/CLIP-ViT-bigG-14-laion2B-39B-b160k \ + --text_prompt examples/prompt_list_0.txt \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ + --save_img_path "./sample_videos/prompt_list_0_29x1280_mt5_openclip" \ + --fps 24 \ + --guidance_scale 7.5 \ + --num_sampling_steps 100 \ + --enable_tiling \ + --max_sequence_length 512 \ + --sample_method EulerAncestralDiscrete \ + --num_samples_per_prompt 1 \ + --rescale_betas_zero_snr \ + --prediction_type "v_prediction" \ + --mode 1 \ No newline at end of file diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x720p.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x720p.sh deleted file mode 100644 index 6d6a726b09..0000000000 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x720p.sh +++ /dev/null @@ -1,19 +0,0 @@ -export DEVICE_ID=0 -python opensora/sample/sample_t2v.py \ - --model_path LanguageBind/Open-Sora-Plan-v1.2.0/29x720p \ - --num_frames 29 \ - --height 720 \ - --width 1280 \ - --cache_dir "./" \ - --text_encoder_name google/mt5-xxl \ - --text_prompt examples/prompt_list_0.txt \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path LanguageBind/Open-Sora-Plan-v1.2.0/vae\ - --save_img_path "./sample_videos/prompt_list_0_29x720p" \ - --fps 24 \ - --guidance_scale 7.5 \ - --num_sampling_steps 100 \ - --enable_tiling \ - --max_sequence_length 512 \ - --sample_method EulerAncestralDiscrete \ - --model_type "dit" \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_1texenc.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_1texenc.sh new file mode 100644 index 0000000000..4f28620f74 --- /dev/null +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_1texenc.sh @@ -0,0 +1,25 @@ +# The DiT model is trained arbitrarily on stride=32. +# So keep the resolution of the inference a multiple of 32. Frames needs to be 4n+1, e.g. 93, 77, 61, 45, 29, 1 (image). + +export DEVICE_ID=0 +python opensora/sample/sample_v1_3.py \ + --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ + --version v1_3 \ + --num_frames 93 \ + --height 352 \ + --width 640 \ + --text_encoder_name_1 google/mt5-xxl \ + --text_prompt examples/prompt_list_0.txt \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ + --save_img_path "./sample_videos/prompt_list_0_93x640_mt5" \ + --fps 18 \ + --guidance_scale 7.5 \ + --num_sampling_steps 100 \ + --enable_tiling \ + --max_sequence_length 512 \ + --sample_method EulerAncestralDiscrete \ + --num_samples_per_prompt 1 \ + --rescale_betas_zero_snr \ + --prediction_type "v_prediction" \ + --mode 1 diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_2texenc.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_2texenc.sh new file mode 100644 index 0000000000..513a9c3925 --- /dev/null +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_2texenc.sh @@ -0,0 +1,23 @@ +export DEVICE_ID=0 +python opensora/sample/sample_v1_3.py \ + --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ + --version v1_3 \ + --num_frames 93 \ + --height 352 \ + --width 640 \ + --text_encoder_name_1 google/mt5-xxl \ + --text_encoder_name_2 laion/CLIP-ViT-bigG-14-laion2B-39B-b160k \ + --text_prompt examples/prompt_list_0.txt \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ + --save_img_path "./sample_videos/prompt_list_0_93x640_mt5_clipbigG" \ + --fps 18 \ + --guidance_scale 7.5 \ + --num_sampling_steps 100 \ + --enable_tiling \ + --max_sequence_length 512 \ + --sample_method EulerAncestralDiscrete \ + --num_samples_per_prompt 1 \ + --rescale_betas_zero_snr \ + --prediction_type "v_prediction" \ + --mode 1 From 164bb089718e737a6cd042cd81ad4f8575101fa6 Mon Sep 17 00:00:00 2001 From: chenyingshu Date: Mon, 4 Nov 2024 12:23:46 +0800 Subject: [PATCH 003/133] align structure --- examples/opensora_pku/README.md | 13 +- .../diffusion/{opensora/rope.py => common.py} | 31 +- .../diffusion/opensora/modeling_opensora.py | 30 +- .../models/diffusion/opensora/modules.py | 31 +- .../opensora/sample/pipeline_opensora.py | 51 +- .../opensora_pku/opensora/sample/sample.py | 721 +---------------- .../opensora/utils/sample_utils.py | 726 ++++++++++++++++++ .../single-device/sample_debug.sh | 36 + .../single-device/sample_t2v_29x1280.sh | 3 +- .../single-device/sample_t2v_29x480p.sh | 17 +- ...93x640_1texenc.sh => sample_t2v_93x640.sh} | 2 +- .../sample_t2v_93x640_2texenc.sh | 5 +- .../readme_load_states.md | 210 +++++ 13 files changed, 1088 insertions(+), 788 deletions(-) rename examples/opensora_pku/opensora/models/diffusion/{opensora/rope.py => common.py} (80%) create mode 100644 examples/opensora_pku/opensora/utils/sample_utils.py create mode 100644 examples/opensora_pku/scripts/text_condition/single-device/sample_debug.sh rename examples/opensora_pku/scripts/text_condition/single-device/{sample_t2v_93x640_1texenc.sh => sample_t2v_93x640.sh} (95%) create mode 100644 examples/opensora_pku/torch_intermediate_states/readme_load_states.md diff --git a/examples/opensora_pku/README.md b/examples/opensora_pku/README.md index b72bbad8e5..4892957897 100644 --- a/examples/opensora_pku/README.md +++ b/examples/opensora_pku/README.md @@ -183,7 +183,7 @@ First, you need to download checkpoint including [diffusion model](https://huggi -You can run text-to-video inference on a single Ascend device using the script `scripts/text_condition/single-device/sample_t2v_29x1280.sh` by modifying `--model_path`, `--text_encoder_name_1` and `--ae_path`. The `--caption_refiner` and `--text_encoder_name_2` are optional. +You can run text-to-video inference on a single Ascend device using the script `scripts/text_condition/single-device/sample_t2v_93x640.sh` by modifying `--model_path`, `--text_encoder_name_1` and `--ae_path`. The `--caption_refiner` and `--text_encoder_name_2` are optional. @@ -192,16 +192,15 @@ You can run text-to-video inference on a single Ascend device using the script ` python opensora/sample/sample.py \ --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ --version v1_3 \ - --num_frames 29 \ - --height 704 \ - --width 1280 \ + --num_frames 93 \ + --height 352 \ + --width 640 \ --text_encoder_name_1 google/mt5-xxl \ - --text_encoder_name_2 laion/CLIP-ViT-bigG-14-laion2B-39B-b160k \ --text_prompt examples/prompt_list_0.txt \ --ae WFVAEModel_D8_4x8x8 \ --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ - --save_img_path "./sample_videos/prompt_list_0_29x1280_mt5_openclip" \ - --fps 24 \ + --save_img_path "./sample_videos/prompt_list_0_93x640" \ + --fps 18 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ --enable_tiling \ diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/rope.py b/examples/opensora_pku/opensora/models/diffusion/common.py similarity index 80% rename from examples/opensora_pku/opensora/models/diffusion/opensora/rope.py rename to examples/opensora_pku/opensora/models/diffusion/common.py index f559d1a5b3..e37fb528fb 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/rope.py +++ b/examples/opensora_pku/opensora/models/diffusion/common.py @@ -7,6 +7,35 @@ from mindspore import mint, nn, ops +# V1.3, Different from v1.2 +class PatchEmbed2D(nn.Cell): + """2D Image to Patch Embedding but with video""" + + def __init__( + self, + patch_size=16, #2 + in_channels=3, #8 + embed_dim=768, # 24*96=2304 + bias=True, + ): + super().__init__() + self.proj = nn.Conv2d( + in_channels, embed_dim, + kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), has_bias=bias + ) + + def construct(self, latent): + b, c, t, h, w = latent.shape # b, c=in_channels, t, h, w + # b c t h w -> (b t) c h w + latent = latent.permute(0, 2, 1, 3, 4).reshape(b*t, c, h, w) # b*t, c, h, w + latent = self.proj(latent) # b*t, embed_dim, h, w + # (b t) c h w -> b (t h w) c + _, c, h, w = latent.shape + latent = latent.reshape(b, -1, c, h, w).permute(0, 1, 3, 4, 2).reshape(b, -1, c) # b, t*h*w, embed_dim + + return latent + + class PositionGetter3D(object): """return positions of patches""" @@ -98,4 +127,4 @@ def construct(self, tokens, positions): y = self.apply_rope1d(y, poses[1], cos_y.to(tokens.dtype), sin_y.to(tokens.dtype)) x = self.apply_rope1d(x, poses[2], cos_x.to(tokens.dtype), sin_x.to(tokens.dtype)) tokens = ops.cat((t, y, x), axis=-1) - return tokens + return tokens \ No newline at end of file diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index 0d85fafba3..e55a7b8c23 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -15,9 +15,10 @@ from mindone.diffusers.models.embeddings import PixArtAlphaTextProjection from mindone.diffusers.models.modeling_utils import ModelMixin, load_state_dict from mindone.diffusers.models.normalization import AdaLayerNormSingle -from mindone.diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, _add_variant, _get_model_file, deprecate +from mindone.diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, _add_variant, _get_model_file -from examples.opensora_pku.opensora.models.diffusion.opensora.modules import BasicTransformerBlock, LayerNorm, Attention, PatchEmbed2D +from opensora.models.diffusion.opensora.modules import BasicTransformerBlock, LayerNorm, Attention +from opensora.models.diffusion.common import PatchEmbed2D class OpenSoraT2V_v1_3(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @@ -115,7 +116,8 @@ def _init_patched_inputs(self): self.adaln_single = AdaLayerNormSingle(self.config.hidden_size) self.max_pool3d = nn.MaxPool3d( kernel_size=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size), - stride=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size) + stride=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size), + pad_mode="pad" ) # rewrite class method to allow the state dict as input @@ -266,7 +268,7 @@ def construct( # return_dict: bool = True, **kwargs, ): - + dtype = ms.float16 batch_size, c, frame, h, w = hidden_states.shape # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. @@ -306,9 +308,8 @@ def construct( frame = ((frame - 1) // self.config.patch_size_t + 1) if frame % 2 == 1 else frame // self.config.patch_size_t # patchfy height, width = hidden_states.shape[-2] // self.config.patch_size, hidden_states.shape[-1] // self.config.patch_size - hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( - hidden_states, encoder_hidden_states, timestep, batch_size, frame + hidden_states, encoder_hidden_states, timestep, batch_size, frame, dtype=dtype ) if get_sequence_parallel_state(): @@ -352,6 +353,7 @@ def construct( width=width, ) # BSH + if get_sequence_parallel_state(): # To (b, t*h*w, h) or (b, t//sp*h*w, h) # s b h -> b s h @@ -367,16 +369,15 @@ def construct( width=width, ) # b c t h w - return output - - def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, batch_size, frame): - - hidden_states = self.pos_embed(hidden_states.to(self.dtype)) # (b, t*h*w, d) + return output + def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, batch_size, frame, dtype=ms.float16): + hidden_states = self.pos_embed(hidden_states.to(dtype)) # (b, t*h*w, d) + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} timestep, embedded_timestep = self.adaln_single( - timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=dtype ) # b 6d, b d encoder_hidden_states = self.caption_projection(encoder_hidden_states) # b, 1, l, d @@ -387,7 +388,6 @@ def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, times return hidden_states, encoder_hidden_states, timestep, embedded_timestep - def _get_output_for_patched_inputs( self, hidden_states, timestep, embedded_timestep, num_frames, height, width ): @@ -432,7 +432,7 @@ def OpenSoraT2V_v1_3_2B_122(**kwargs): args = type('args', (), { - 'ae': "CausalVAEModel_D4_4x8x8", #'WFVAEModel_D8_4x8x8', + 'ae': "WFVAEModel_D8_4x8x8", 'model_max_length': 300, 'max_height': 256, 'max_width': 512, @@ -480,6 +480,8 @@ def OpenSoraT2V_v1_3_2B_122(**kwargs): ckpt = safe_load(path, device="cpu") msg = model.load_state_dict(ckpt, strict=True) print(msg) + # some difference from sample.py + # e.g. do not have mix precision except Exception as e: print(e) # print(model) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py index 212846dae0..84c9a83abb 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py @@ -14,7 +14,7 @@ from mindone.diffusers.models.attention_processor import Attention as Attention_ from mindone.utils.version_control import check_valid_flash_attention, choose_flash_attention_dtype -from .rope import PositionGetter3D, RoPE3D +from ..common import PositionGetter3D, RoPE3D logger = logging.getLogger(__name__) @@ -38,35 +38,6 @@ def construct(self, x: ms.Tensor): x, _, _ = self.layer_norm(x, self.gamma, self.beta) return x -# Different from v1.2 -class PatchEmbed2D(nn.Cell): - """2D Image to Patch Embedding but with video""" - - def __init__( - self, - patch_size=16, #2 - in_channels=3, #8 - embed_dim=768, # 24*96=2304 - bias=True, - ): - super().__init__() - self.proj = nn.Conv2d( - in_channels, embed_dim, - kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), has_bias=bias - ) - - def construct(self, latent): - b, c, t, h, w = latent.shape # b, c=in_channels, t, h, w - # b c t h w -> (b t) c h w - latent = latent.permute(0, 2, 1, 3, 4).reshape(b*t, c, h, w) # b*t, c, h, w - latent = self.proj(latent) # b*t, embed_dim, h, w - # (b t) c h w -> b (t h w) c - _, c, h, w = latent.shape - latent = latent.reshape(b, -1, c, h, w).permute(0, 1, 3, 4, 2).reshape(b, -1, c) # b, t*h*w, embed_dim - - return latent - - def get_attention_mask(attention_mask, repeat_num, attention_mode="xformers"): if attention_mask is not None: if attention_mode != "math": diff --git a/examples/opensora_pku/opensora/sample/pipeline_opensora.py b/examples/opensora_pku/opensora/sample/pipeline_opensora.py index a43eddec01..77f120cbdb 100644 --- a/examples/opensora_pku/opensora/sample/pipeline_opensora.py +++ b/examples/opensora_pku/opensora/sample/pipeline_opensora.py @@ -16,6 +16,7 @@ from mindone.diffusers.utils import BACKENDS_MAPPING, deprecate, is_bs4_available, is_ftfy_available, BaseOutput from mindone.diffusers import AutoencoderKL from mindone.diffusers import DDPMScheduler, FlowMatchEulerDiscreteScheduler +from mindone.diffusers.utils.mindspore_utils import randn_tensor # from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback #TODO:TBD from mindone.transformers import CLIPTextModelWithProjection, T5EncoderModel @@ -46,15 +47,14 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ - std_text = noise_pred_text.std(axis=list(range(1, noise_pred_text.ndim)), ddof=True, keepdims=True) - std_cfg = noise_cfg.std(axis=list(range(1, noise_cfg.ndim)), ddof=True, keepdims=True) + std_text = ops.std(noise_pred_text, axis=tuple(range(1, len(noise_pred_text.shape))), keepdims=True) + std_cfg = ops.std(noise_cfg, axis=tuple(range(1, len(noise_cfg.shape))), keepdims=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg return noise_cfg - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -238,15 +238,15 @@ def encode_prompt( padding="max_length", max_length=max_length, truncation=True, - # return_attention_mask=True, + return_attention_mask=True, return_tensors=None, ) text_input_ids = ms.Tensor(text_inputs.input_ids) untruncated_ids = ms.Tensor(tokenizer(prompt, padding="longest", return_tensors=None).input_ids) if ( - untruncated_ids.shape[-1] >= text_input_ids.shape[-1] - and not ops.equal(text_input_ids, untruncated_ids[:, :text_input_ids.shape[-1]]).all() + untruncated_ids.shape[-1] > text_input_ids.shape[-1] or + (untruncated_ids.shape[-1] == text_input_ids.shape[-1] and not ops.equal(text_input_ids, untruncated_ids).all()) ): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( @@ -439,14 +439,14 @@ def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, ) if latents is None: - latents = ops.randn(shape, dtype=dtype) #generator=generator + latents = randn_tensor(shape=shape, generator=generator, dtype=dtype) else: latents = latents if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma - return latents + return latents.to(dtype) def prepare_parallel_latent(self, video_states): sp_size = hccl_info.world_size @@ -498,7 +498,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, num_samples_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, - generator = None, + generator: Optional[np.random.Generator] = None, latents: Optional[ms.Tensor] = None, prompt_embeds: Optional[ms.Tensor] = None, prompt_embeds_2: Optional[ms.Tensor] = None, @@ -513,9 +513,12 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, ms.Tensor], None]] = None, # Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]] callback_on_step_end_tensor_inputs: List[str] = ["latents"], guidance_rescale: float = 0.0, - max_sequence_length: int = 512 + max_sequence_length: int = 512, ): - # TODO + + # TODO + if hasattr(callback_on_step_end, 'tensor_inputs'): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): # callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -524,6 +527,7 @@ def __call__( height = height or self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1] width = width or self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2] + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, @@ -563,7 +567,7 @@ def __call__( negative_prompt_attention_mask, ) = self.encode_prompt( prompt=prompt, - dtype=self.transformer.dtype, + dtype=ms.float16, #self.transformer.dtype, num_samples_per_prompt=num_samples_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, @@ -574,6 +578,7 @@ def __call__( max_sequence_length=max_sequence_length, text_encoder_index=0, ) + if self.tokenizer_2 is not None: ( prompt_embeds_2, @@ -582,7 +587,7 @@ def __call__( negative_prompt_attention_mask_2, ) = self.encode_prompt( prompt=prompt, - dtype=self.transformer.dtype, + dtype=ms.float16, #self.transformer.dtype, num_samples_per_prompt=num_samples_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, @@ -663,15 +668,14 @@ def __call__( prompt_embeds = prompt_embeds[:, index, :, :] else: temp_attention_mask = None - # ==================make sp===================================== + # ==================make sp===================================== # 8. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue - # print(f"t {t}:, latents {latents.shape}") # expand the latents if we are doing classifier free guidance latent_model_input = ops.cat([latents] * 2) if self.do_classifier_free_guidance else latents if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): @@ -692,7 +696,8 @@ def __call__( if prompt_attention_mask.ndim == 2: prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l if prompt_embeds_2 is not None and prompt_embeds_2.ndim == 2: - prompt_embeds = prompt_embeds.unsqueeze(1) # b d -> b 1 d + prompt_embeds = prompt_embeds.unsqueeze(1) # b d -> b 1 d #OFFICIAL VER. DONT KNOW WHY + # prompt_embeds_2 = prompt_embeds_2.unsqueeze(1) # attention_mask = ops.ones_like(latent_model_input)[:, 0] if temp_attention_mask is not None: @@ -715,11 +720,12 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, timestep=current_timestep, - pooled_projections=prompt_embeds_2, + pooled_projections=prompt_embeds_2, # UNUSED!!!! return_dict=False, ) ) # b,c,t,h,w assert not ops.any(ops.isnan(noise_pred.float())) + # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) @@ -772,20 +778,19 @@ def __call__( return (videos, ) return OpenSoraPipelineOutput(videos=videos) - # TODO: update OpenSoraPipelineOutput according to VAE usage - + # def decode_latents(self, latents): # print(f'before vae decode {latents.shape}', ops.max(latents).item(), ops.min(latents).item(), ops.mean(latents).item(), ops.std(latents).item()) - # video = self.vae.decode(latents.to(self.vae.vae.dtype)) + # video = self.vae.decode(latents.to(self.vae.vae.dtype)) # (b t c h w) # print(f'after vae decode {latents.shape}', ops.max(video).item(), ops.min(video).item(), ops.mean(video).item(), ops.std(video).item()) # video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=ms.uint8).permute(0, 1, 3, 4, 2).contiguous() # b t h w c # return video - - #TODO: TBD depending on VAE usage + def decode_latents_per_sample(self, latents): - # video = self.vae.decode(latents) + print(f'before vae decode {latents.shape}', latents.max().item(), latents.min().item(), latents.mean().item(), latents.std().item()) video = self.vae.decode(latents).to(ms.float32) # (b t c h w) + print(f'after vae decode {latents.shape}', latents.max().item(), latents.min().item(), latents.mean().item(), latents.std().item()) video = ops.clip_by_value((video / 2.0 + 0.5), clip_value_min=0.0, clip_value_max=1.0).permute(0, 1, 3, 4, 2) return video # b t h w c diff --git a/examples/opensora_pku/opensora/sample/sample.py b/examples/opensora_pku/opensora/sample/sample.py index 47a5ff8e6f..1320dd90f2 100644 --- a/examples/opensora_pku/opensora/sample/sample.py +++ b/examples/opensora_pku/opensora/sample/sample.py @@ -1,731 +1,48 @@ -import argparse -import glob -import logging -import os -import sys -import time - -import numpy as np -import pandas as pd -import yaml -from PIL import Image -from tqdm import tqdm - -import mindspore as ms -from mindspore import nn - - +import os, sys # TODO: remove in future when mindone is ready for install mindone_lib_path = os.path.abspath("../../") sys.path.insert(0, mindone_lib_path) sys.path.append(os.path.abspath("./")) -from opensora.npu_config import npu_config -from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info -from opensora.dataset.text_dataset import create_dataloader -from opensora.utils.message_utils import print_banner -from opensora.utils.ms_utils import init_env -from opensora.utils.utils import _check_cfgs_in_parser, get_precision -from opensora.models.causalvideovae import ae_stride_config, ae_wrapper -# from opensora.sample.caption_refiner import OpenSoraCaptionRefiner #TODO -from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate -from examples.opensora_pku.opensora.models.diffusion.opensora.modeling_opensora import LayerNorm, OpenSoraT2V_v1_3 -from examples.opensora_pku.opensora.models.diffusion.opensora.modules import Attention -from opensora.sample.pipeline_opensora import OpenSoraPipeline -from mindone.diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings -from mindone.diffusers import ( - DDIMScheduler, DDPMScheduler, PNDMScheduler, - EulerDiscreteScheduler, DPMSolverMultistepScheduler, - HeunDiscreteScheduler, EulerAncestralDiscreteScheduler, - DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler, - DPMSolverSinglestepScheduler, #CogVideoXDDIMScheduler, - FlowMatchEulerDiscreteScheduler - ) -from mindone.transformers import T5EncoderModel, MT5EncoderModel, CLIPTextModelWithProjection -from transformers import AutoTokenizer, MT5Tokenizer +import logging +import time + +from opensora.npu_config import npu_config +from opensora.utils.sample_utils import ( + prepare_pipeline, get_args, run_model_and_save_samples +) +# from opensora.sample.caption_refiner import OpenSoraCaptionRefiner -from mindone.utils.amp import auto_mixed_precision -from mindone.utils.config import str2bool from mindone.utils.logger import set_logger -from mindone.utils.params import count_params -from mindone.visualize.videos import save_videos logger = logging.getLogger(__name__) -# ms.set_context(pynative_synchronize=True) - -# Copied from opensora.utils sample_utils.py -def get_scheduler(args): - kwargs = dict( - prediction_type=args.prediction_type, - rescale_betas_zero_snr=args.rescale_betas_zero_snr, - timestep_spacing="trailing" if args.rescale_betas_zero_snr else 'leading', - ) - if args.v1_5_scheduler: - kwargs['beta_start'] = 0.00085 - kwargs['beta_end'] = 0.0120 - kwargs['beta_schedule'] = "scaled_linear" - if args.sample_method == 'DDIM': - scheduler_cls = DDIMScheduler - kwargs['clip_sample'] = False - elif args.sample_method == 'EulerDiscrete': - scheduler_cls = EulerDiscreteScheduler - elif args.sample_method == 'DDPM': - scheduler_cls = DDPMScheduler - kwargs['clip_sample'] = False - elif args.sample_method == 'DPMSolverMultistep': - scheduler_cls = DPMSolverMultistepScheduler - elif args.sample_method == 'DPMSolverSinglestep': - scheduler_cls = DPMSolverSinglestepScheduler - elif args.sample_method == 'PNDM': - scheduler_cls = PNDMScheduler - kwargs.pop('rescale_betas_zero_snr', None) - elif args.sample_method == 'HeunDiscrete': ######## - scheduler_cls = HeunDiscreteScheduler - elif args.sample_method == 'EulerAncestralDiscrete': - scheduler_cls = EulerAncestralDiscreteScheduler - elif args.sample_method == 'DEISMultistep': - scheduler_cls = DEISMultistepScheduler - kwargs.pop('rescale_betas_zero_snr', None) - elif args.sample_method == 'KDPM2AncestralDiscrete': ######### - scheduler_cls = KDPM2AncestralDiscreteScheduler - # elif args.sample_method == 'CogVideoX': - # scheduler_cls = CogVideoXDDIMScheduler - elif args.sample_method == 'FlowMatchEulerDiscrete': - scheduler_cls = FlowMatchEulerDiscreteScheduler - kwargs = {} - else: - raise NameError(f'Unsupport sample_method {args.sample_method}') - scheduler = scheduler_cls(**kwargs) - return scheduler - -def parse_args(): - parser = argparse.ArgumentParser() - - parser.add_argument("--version", type=str, default='v1_3', choices=['v1_3', 'v1_5']) - parser.add_argument("--caption_refiner", type=str, default=None) - parser.add_argument("--enhance_video", type=str, default=None) - parser.add_argument("--text_encoder_name_1", type=str, default='DeepFloyd/t5-v1_1-xxl', help="google/mt5-xxl, DeepFloyd/t5-v1_1-xxl") - parser.add_argument("--text_encoder_name_2", type=str, default=None, help=" openai/clip-vit-large-patch14, (laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)") - parser.add_argument("--num_samples_per_prompt", type=int, default=1) - parser.add_argument('--refine_caption', action='store_true') - # parser.add_argument('--compile', action='store_true') - parser.add_argument("--prediction_type", type=str, default='epsilon', help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.") - parser.add_argument('--rescale_betas_zero_snr', action='store_true') - # parser.add_argument('--local_rank', type=int, default=-1) - # parser.add_argument('--world_size', type=int, default=1) - # parser.add_argument('--sp', action='store_true') - - parser.add_argument('--v1_5_scheduler', action='store_true') - parser.add_argument('--conditional_pixel_values_path', type=str, default=None) - parser.add_argument('--mask_type', type=str, default=None) - parser.add_argument('--crop_for_hw', action='store_true') - parser.add_argument('--max_hxw', type=int, default=236544) # 480*480 - - parser.add_argument( - "--config", - "-c", - default="", - type=str, - help="path to load a config yaml file that describes the setting which will override the default arguments", - ) - parser.add_argument("--model_path", type=str, default="LanguageBind/Open-Sora-Plan-v1.3.0") - parser.add_argument( - "--ms_checkpoint", - type=str, - default=None, - help="If not provided, will search for ckpt file under `model_path`" - "If provided, will use this pretrained ckpt path.", - ) - parser.add_argument("--num_frames", type=int, default=1) - parser.add_argument("--height", type=int, default=512) - parser.add_argument("--width", type=int, default=512) - parser.add_argument("--ae", type=str, default="CausalVAEModel_4x8x8") - parser.add_argument("--ae_path", type=str, default="CausalVAEModel_4x8x8") - parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel") - - parser.add_argument("--text_encoder_name", type=str, default="DeepFloyd/t5-v1_1-xxl") - parser.add_argument("--save_img_path", type=str, default="./sample_videos/t2v") - - parser.add_argument("--guidance_scale", type=float, default=7.5, help="the scale for classifier-free guidance") - parser.add_argument("--max_sequence_length", type=int, default=512, help="the maximum text tokens length") - - parser.add_argument("--sample_method", type=str, default="PNDM") - parser.add_argument("--num_sampling_steps", type=int, default=50, help="Diffusion Sampling Steps") - parser.add_argument("--fps", type=int, default=24) - parser.add_argument( - "--text_prompt", - type=str, - nargs="+", - help="A list of text prompts to be generated with. Also allow input a txt file or csv file.", - ) - parser.add_argument("--tile_overlap_factor", type=float, default=0.25) - - parser.add_argument("--enable_tiling", action="store_true", help="whether to use vae tiling to save memory") - parser.add_argument("--model_3d", action="store_true") - parser.add_argument("--udit", action="store_true") - parser.add_argument("--save_memory", action="store_true") - parser.add_argument("--batch_size", default=1, type=int, help="batch size for dataloader") - # MS new args - parser.add_argument("--device", type=str, default="Ascend", help="Ascend or GPU") - parser.add_argument("--max_device_memory", type=str, default=None, help="e.g. `30GB` for 910a, `59GB` for 910b") - parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") - parser.add_argument("--use_parallel", default=False, type=str2bool, help="use parallel") - parser.add_argument( - "--parallel_mode", default="data", type=str, choices=["data", "optim"], help="parallel mode: data, optim" - ) - parser.add_argument("--jit_level", default="O0", help="Set jit level: # O0: KBK, O1:DVM, O2: GE") - parser.add_argument( - "--jit_syntax_level", default="strict", choices=["strict", "lax"], help="Set jit syntax level: strict or lax" - ) - parser.add_argument("--seed", type=int, default=42, help="Inference seed") - - parser.add_argument( - "--precision", - default="bf16", - type=str, - choices=["bf16", "fp16", "fp32"], - help="what data type to use for latte. Default is `fp16`, which corresponds to ms.float16", - ) - parser.add_argument( - "--global_bf16", action="store_true", help="whether to enable gloabal bf16 for diffusion model training." - ) - parser.add_argument( - "--vae_precision", - default="fp16", - type=str, - choices=["bf16", "fp16"], - help="what data type to use for vae. Default is `bf16`, which corresponds to ms.bfloat16", - ) - parser.add_argument( - "--vae_keep_gn_fp32", - default=False, - type=str2bool, - help="whether keep GroupNorm in fp32. Defaults to False in inference mode. If training vae, better set it to True", - ) - parser.add_argument( - "--text_encoder_precision", - default="fp16", - type=str, - choices=["bf16", "fp16"], - help="what data type to use for T5 text encoder. Default is `bf16`, which corresponds to ms.bfloat16", - ) - parser.add_argument( - "--amp_level", type=str, default="O2", help="Set the amp level for the transformer model. Defaults to O2." - ) - parser.add_argument( - "--precision_mode", - default=None, - type=str, - help="If specified, set the precision mode for Ascend configurations.", - ) - parser.add_argument( - "--num_videos_per_prompt", type=int, default=1, help="the number of images to be generated for each prompt" - ) - parser.add_argument( - "--save_latents", - action="store_true", - help="Whether to save latents (before vae decoding) instead of video files.", - ) - parser.add_argument( - "--decode_latents", - action="store_true", - help="whether to load the existing latents saved in npy files and run vae decoding", - ) - parser.add_argument( - "--video_extension", default="mp4", choices=["gif", "mp4"], help="The file extension to save videos" - ) - parser.add_argument("--model_type", type=str, default="dit", choices=["dit", "udit", "latte", 't2v', 'inpaint', 'i2v']) - parser.add_argument("--cache_dir", type=str, default="./") - parser.add_argument("--profile", default=False, type=str2bool, help="Profile or not") - default_args = parser.parse_args() - abs_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "")) - if default_args.config: - logger.info(f"Overwrite default arguments with configuration file {default_args.config}") - default_args.config = os.path.join(abs_path, default_args.config) - with open(default_args.config, "r") as f: - cfg = yaml.safe_load(f) - _check_cfgs_in_parser(cfg, parser) - parser.set_defaults(**cfg) - args = parser.parse_args() - - assert not (args.use_parallel and args.num_frames == 1) - - return args - - -def prepare_pipeline(args): - # VAE model initiate and weight loading - print_banner("vae init") - vae_dtype = get_precision(args.vae_precision) - if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint): - logger.info(f"Run inference with MindSpore checkpoint {args.ms_checkpoint}") - state_dict = ms.load_checkpoint(args.ms_checkpoint) - # rm 'network.' prefix - state_dict = dict( - [k.replace("network.", "") if k.startswith("network.") else k, v] for k, v in state_dict.items() - ) - else: - state_dict = None - kwarg = { - "state_dict": state_dict, - "use_safetensors": True, - "dtype": vae_dtype, - } - vae = ae_wrapper[args.ae](args.ae_path, **kwarg) - vae.vae_scale_factor = ae_stride_config[args.ae] - if args.enable_tiling: - vae.vae.enable_tiling() - vae.vae.tile_overlap_factor = args.tile_overlap_factor - - ## use amp level O2 for causal 3D VAE with bfloat16 or float16 - if vae_dtype == ms.float16: - custom_fp32_cells = [nn.GroupNorm] if args.vae_keep_gn_fp32 else [] - else: - custom_fp32_cells = [nn.AvgPool2d, TrilinearInterpolate] - logger.info(f"Use amp level O2 for causal 3D VAE with dtype={vae_dtype}, custom_fp32_cells: {custom_fp32_cells}") - vae = auto_mixed_precision(vae, amp_level="O2", dtype=vae_dtype, custom_fp32_cells=custom_fp32_cells) - - vae.set_train(False) - for param in vae.get_parameters(): # freeze vae - param.requires_grad = False - - if args.decode_latents: - print("To decode latents directly, skipped loading text endoers and transformer") - return vae - - # Build text encoders - print_banner("text encoder init") - text_encoder_dtype = get_precision(args.text_encoder_precision) - if 'mt5' in args.text_encoder_name_1: - text_encoder_1, loading_info = MT5EncoderModel.from_pretrained( - args.text_encoder_name_1, - cache_dir=args.cache_dir, - output_loading_info=True, - mindspore_dtype=text_encoder_dtype, - use_safetensors=True - ) - loading_info.pop("unexpected_keys") # decoder weights are ignored - logger.info(f"Loaded MT5 Encoder: {loading_info}") - text_encoder_1 = text_encoder_1.set_train(False) - else: - text_encoder_1 = T5EncoderModel.from_pretrained( - args.text_encoder_name_1, cache_dir=args.cache_dir, - mindspore_dtype=text_encoder_dtype - ).set_train(False) - tokenizer_1 = AutoTokenizer.from_pretrained( - args.text_encoder_name_1, cache_dir=args.cache_dir - ) - - if args.text_encoder_name_2 is not None: - text_encoder_2, loading_info = CLIPTextModelWithProjection.from_pretrained( - args.text_encoder_name_2, - cache_dir=args.cache_dir, - mindspore_dtype=text_encoder_dtype, - output_loading_info=True, - use_safetensors=True, - ) - loading_info.pop("unexpected_keys") # only load text model, ignore vision model - logger.info(f"Loaded CLIP Encoder: {loading_info}") # Note: missed keys when loading open-clip models - text_encoder_2 = text_encoder_2.set_train(False) - tokenizer_2 = AutoTokenizer.from_pretrained( - args.text_encoder_name_2, cache_dir=args.cache_dir - ) - else: - text_encoder_2, tokenizer_2 = None, None - - # Build transformer - print_banner("transformer model init") - FA_dtype = get_precision(args.precision) if get_precision(args.precision) != ms.float32 else ms.bfloat16 - assert args.model_type == "dit", "Currently only suppport model_type as 'dit'@" - if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint): - logger.info(f"Initiate from MindSpore checkpoint file {args.ms_checkpoint}") - state_dict = ms.load_checkpoint(args.ms_checkpoint) - # rm 'network.' prefix - state_dict = dict( - [k.replace("network.", "") if k.startswith("network.") else k, v] for k, v in state_dict.items() - ) - else: - state_dict = None - model_version = args.model_path.split("/")[-1] - if (args.version != 'v1_3') and (model_version.split("x")[0][:3] != "any"): - if int(model_version.split("x")[0]) != args.num_frames: - logger.warning( - f"Detect that the loaded model version is {model_version}, but found a mismatched number of frames {model_version.split('x')[0]}" - ) - if int(model_version.split("x")[1][:-1]) != args.height: - logger.warning( - f"Detect that the loaded model version is {model_version}, but found a mismatched resolution {args.height}x{args.width}" - ) - elif (args.version == 'v1_3') and (model_version.split("x")[0] == "any93x640x640"): # TODO: currently only release one model - if (args.height % 32 != 0) or (args.width % 32 != 0): - logger.warning( - f"Detect that the loaded model version is {model_version}, but found a mismatched resolution {args.height}x{args.width}. The resolution of the inference should be a multiple of 32." - ) - if (args.num_frames - 1) % 4 != 0: - logger.warning( - f"Detect that the loaded model version is {model_version}, but found a mismatched number of frames {args.num_frames}. Frames needs to be 4n+1, e.g. 93, 77, 61, 45, 29, 1 (image)" - ) - - if args.version == 'v1_3': - # TODO - # if args.model_type == 'inpaint' or args.model_type == 'i2v': - # transformer_model = OpenSoraInpaint_v1_3.from_pretrained( - # args.model_path, cache_dir=args.cache_dir, - # device_map=None, mindspore_dtype=weight_dtype - # ).set_train(False) - # else: - transformer_model, logging_info = OpenSoraT2V_v1_3.from_pretrained( - args.model_path, - state_dict=state_dict, - cache_dir=args.cache_dir, - FA_dtype = FA_dtype, - output_loading_info=True - ) - logger.info(logging_info) - elif args.version == 'v1_5': - if args.model_type == 'inpaint' or args.model_type == 'i2v': - raise NotImplementedError('Inpainting model is not available in v1_5') - else: - from opensora.models.diffusion.opensora_v1_5.modeling_opensora import OpenSoraT2V_v1_5 - transformer_model = OpenSoraT2V_v1_5.from_pretrained( - args.model_path, cache_dir=args.cache_dir, - # device_map=None, - mindspore_dtype=weight_dtype - ) - - # Mixed precision - dtype = get_precision(args.precision) - if args.precision in ["fp16", "bf16"]: - if not args.global_bf16: - amp_level = args.amp_level - transformer_model = auto_mixed_precision( - transformer_model, - amp_level=args.amp_level, - dtype=dtype, - custom_fp32_cells=[LayerNorm, Attention, nn.SiLU, nn.GELU, PixArtAlphaCombinedTimestepSizeEmbeddings] - if dtype == ms.float16 - else [ - nn.MaxPool2d, - nn.MaxPool3d, - LayerNorm, - nn.SiLU, - nn.GELU, - PixArtAlphaCombinedTimestepSizeEmbeddings, - ], - ) - logger.info( - f"Set mixed precision to {args.amp_level} with dtype={args.precision}, custom fp32_cells {custom_fp32_cells}" - ) - else: - logger.info(f"Using global bf16. Force model dtype from {dtype} to ms.bfloat16") - dtype = ms.bfloat16 - elif args.precision == "fp32": - amp_level = "O0" - else: - raise ValueError(f"Unsupported precision {args.precision}") - transformer_model = transformer_model.set_train(False) - for param in transformer_model.get_parameters(): # freeze transformer_model - param.requires_grad = False - - # Build scheduler - scheduler = get_scheduler(args) - - # Build inference pipeline - # pipeline_class = OpenSoraInpaintPipeline if args.model_type == 'inpaint' or args.model_type == 'i2v' else OpenSoraPipeline - pipeline_class = OpenSoraPipeline - - pipeline = pipeline_class( - vae=vae, - text_encoder=text_encoder_1, - tokenizer=tokenizer_1, - scheduler=scheduler, - transformer=transformer_model, - text_encoder_2=text_encoder_2, - tokenizer_2=tokenizer_2, - ) - - # Print key info - num_params_vae, num_params_vae_trainable = count_params(vae) - num_params_latte, num_params_latte_trainable = count_params(transformer_model) - num_params = num_params_vae + num_params_latte - num_params_trainable = num_params_vae_trainable + num_params_latte_trainable - key_info = "Key Settings:\n" + "=" * 50 + "\n" - key_info += "\n".join( - [ - f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}", - f"Jit level: {args.jit_level}", - f"Num of samples: {len(args.text_prompt)}", - f"Num params: {num_params:,} (latte: {num_params_latte:,}, vae: {num_params_vae:,})", - f"Num trainable params: {num_params_trainable:,}", - f"Transformer dtype: {dtype}", - f"VAE dtype: {vae_dtype}", - f"Text encoder dtype: {text_encoder_dtype}", - f"Sampling steps {args.num_sampling_steps}", - f"Sampling method: {args.sample_method}", - f"CFG guidance scale: {args.guidance_scale}", - f"FA dtype: {FA_dtype}", - f"Inference shape (num_frames x height x width): {args.num_frames}x{args.height}x{args.width}", - ] - ) - key_info += "\n" + "=" * 50 - logger.info(key_info) - - return pipeline - -def run_model_and_save_samples(args, pipeline, rank_id, device_num, caption_refiner_model=None, enhance_video_model=None): - - # Handle input text prompts - print_banner("text prompts loading") - ext = ( - f"{args.video_extension}" if not (args.save_latents or args.decode_latents) else "npy" - ) # save video as gif or save denoised latents as npy files. - ext = "jpg" if args.num_frames == 1 else ext - if not isinstance(args.text_prompt, list): - args.text_prompt = [args.text_prompt] - # if input is a text file, where each line is a caption, load it into a list - if len(args.text_prompt) == 1 and args.text_prompt[0].endswith("txt"): - captions = open(args.text_prompt[0], "r").readlines() - args.text_prompt = [i.strip() for i in captions] - if len(args.text_prompt) == 1 and args.text_prompt[0].endswith("csv"): - captions = pd.read_csv(args.text_prompt[0]) - args.text_prompt = [i.strip() for i in captions["cap"]] - n = len(args.text_prompt) - assert n > 0, "No captions provided" - logger.info(f"Number of prompts: {n}") - logger.info(f"Number of generated samples for each prompt {args.num_videos_per_prompt}") - - # Create dataloader for the captions - csv_file = {"path": [], "cap": []} - for i in range(n): - for i_video in range(args.num_videos_per_prompt): - csv_file["path"].append(f"{i_video}-{args.text_prompt[i].strip()[:100]}.{ext}") - csv_file["cap"].append(args.text_prompt[i]) - temp_dataset_csv = os.path.join(save_dir, "dataset.csv") - pd.DataFrame.from_dict(csv_file).to_csv(temp_dataset_csv, index=False, columns=csv_file.keys()) - - ds_config = dict( - data_file_path=temp_dataset_csv, - tokenizer=None, # tokenizer, - file_column="path", - caption_column="cap", - ) - dataset = create_dataloader( - ds_config, - args.batch_size, - ds_name="text", - num_parallel_workers=12, - max_rowsize=32, - shuffle=False, # be in order - device_num=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), - rank_id=rank_id if not get_sequence_parallel_state() else hccl_info.group_id, - drop_remainder=False, - ) - dataset_size = dataset.get_dataset_size() - logger.info(f"Num batches: {dataset_size}") - ds_iter = dataset.create_dict_iterator(1, output_numpy=True) - - # Decode latents directly - if args.decode_latents: - print("Directly decoding latents...") - assert isinstance(pipeline, ae_wrapper[args.ae]) - vae = pipeline - for step, data in tqdm(enumerate(ds_iter), total=dataset_size): - file_paths = data["file_path"] - loaded_latents = [] - for i_sample in range(args.batch_size): - save_fp = os.path.join(save_dir, file_paths[i_sample]) - assert os.path.exists( - save_fp - ), f"{save_fp} does not exist! Please check the npy files under {save_dir} or check if you run `--save_latents` ahead." - loaded_latents.append(np.load(save_fp)) - loaded_latents = ( - np.stack(loaded_latents) if loaded_latents[0].ndim == 4 else np.concatenate(loaded_latents, axis=0) - ) - decode_data = ( - vae.decode(ms.Tensor(loaded_latents)).permute(0, 1, 3, 4, 2).to(ms.float32) - ) # (b t c h w) -> (b t h w c) - decode_data = ms.ops.clip_by_value( - (decode_data + 1.0) / 2.0, clip_value_min=0.0, clip_value_max=1.0 - ).asnumpy() - for i_sample in range(args.batch_size): - save_fp = os.path.join(save_dir, file_paths[i_sample]).replace(".npy", f".{args.video_extension}") - save_video_data = decode_data[i_sample : i_sample + 1] - save_videos(save_video_data, save_fp, loop=0, fps=args.fps) # (b t h w c) - - # Delete files that are no longer needed - if os.path.exists(temp_dataset_csv): - os.remove(temp_dataset_csv) - - if args.decode_latents: - npy_files = glob.glob(os.path.join(save_dir, "*.npy")) - for fp in npy_files: - os.remove(fp) - - # TODO - # if args.model_type == 'inpaint' or args.model_type == 'i2v': - # if not isinstance(args.conditional_pixel_values_path, list): - # args.conditional_pixel_values_path = [args.conditional_pixel_values_path] - # if len(args.conditional_pixel_values_path) == 1 and args.conditional_pixel_values_path[0].endswith('txt'): - # temp = open(args.conditional_pixel_values_path[0], 'r').readlines() - # conditional_pixel_values_path = [i.strip().split(',') for i in temp] - # mask_type = args.mask_type if args.mask_type is not None else None - - # positive_prompt = """ - # high quality, high aesthetic, {} - # """ - # negative_prompt = """ - # nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, - # low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry. - # """ - positive_prompt = ( - "(masterpiece), (best quality), (ultra-detailed), {}. emotional, " - + "harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo, sharp focus, high budget, cinemascope, moody, epic, gorgeous" - ) - negative_prompt = ( - "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, " - + "extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" - ) - - def generate(step, data, ext, conditional_pixel_values_path=None, mask_type=None): - - # TODO - # if args.caption_refiner is not None: - # if args.model_type != 'inpaint' and args.model_type != 'i2v': - # refine_prompt = caption_refiner_model.get_refiner_output(prompt) - # print(f'\nOrigin prompt: {prompt}\n->\nRefine prompt: {refine_prompt}') - # prompt = refine_prompt - # else: - # # Due to the current use of LLM as the caption refiner, additional content that is not present in the control image will be added. Therefore, caption refiner is not used in this mode. - # print('Caption refiner is not available for inpainting model, use the original prompt...') - # time.sleep(3) - # input_prompt = positive_prompt.format(prompt) - # if args.model_type == 'inpaint' or args.model_type == 'i2v': - # print(f'\nConditional pixel values path: {conditional_pixel_values_path}') - # videos = pipeline( - # conditional_pixel_values_path=conditional_pixel_values_path, - # mask_type=mask_type, - # crop_for_hw=args.crop_for_hw, - # max_hxw=args.max_hxw, - # prompt=input_prompt, - # negative_prompt=negative_prompt, - # num_frames=args.num_frames, - # height=args.height, - # width=args.width, - # num_inference_steps=args.num_sampling_steps, - # guidance_scale=args.guidance_scale, - # num_samples_per_prompt=args.num_samples_per_prompt, - # max_sequence_length=args.max_sequence_length, - # ).videos - # else: - prompt = [x for x in data["caption"]] - file_paths = data["file_path"] - input_prompt = positive_prompt.format(prompt) - videos = ( - pipeline( - input_prompt, - negative_prompt=negative_prompt, - num_frames=args.num_frames, - height=args.height, - width=args.width, - num_inference_steps=args.num_sampling_steps, - guidance_scale=args.guidance_scale, - num_samples_per_prompt=args.num_samples_per_prompt, - output_type="latents" if args.save_latents else "pil", - max_sequence_length=args.max_sequence_length, - ) - .videos.to(ms.float32) - .asnumpy() - ) - # if enhance_video_model is not None: - # # b t h w c - # videos = enhance_video_model.enhance_a_video(videos, input_prompt, 2.0, args.fps, 250) - if step == 0 and profiler is not None: - profiler.stop() - - if get_sequence_parallel_state() and hccl_info.rank % hccl_info.world_size != 0: - pass - else: - # save result - for i_sample in range(args.batch_size): - file_path = os.path.join(save_dir, file_paths[i_sample]) - assert ext in file_path, f"Only support saving as {ext} files, but got {file_path}." - if args.save_latents: - np.save(file_path, videos[i_sample : i_sample + 1]) - else: - if args.num_frames == 1: - ext = "jpg" - image = videos[i_sample, 0] # (b t h w c) -> (h, w, c) - image = (image * 255).round().clip(0, 255).astype(np.uint8) - Image.fromarray(image).save(file_path) - else: - save_video_data = videos[i_sample : i_sample + 1] # (b t h w c) - save_videos(save_video_data, file_path, loop=0, fps=args.fps) - - if args.profile: - profiler = ms.Profiler(output_path="./mem_info", profile_memory=True) - ms.set_context(memory_optimize_level="O0") - ms.set_context(pynative_synchronize=True) - else: - profiler = None - - # Infer - # if args.model_type == 'inpaint' or args.model_type == 'i2v': - # for index, (prompt, cond_path) in enumerate(zip(args.text_prompt, conditional_pixel_values_path)): - # if not args.sp and args.local_rank != -1 and index % args.world_size != args.local_rank: - # continue - # generate(prompt, conditional_pixel_values_path=cond_path, mask_type=mask_type) - # print('completed, please check the saved images and videos') - # else: - for step, data in tqdm(enumerate(ds_iter), total=dataset_size): - generate(step, data, ext) - - - # Delete files that are no longer needed - if os.path.exists(temp_dataset_csv): - os.remove(temp_dataset_csv) - - - if __name__ == "__main__": - args = parse_args() + args = get_args() save_dir = args.save_img_path os.makedirs(save_dir, exist_ok=True) set_logger(name="", output_dir=save_dir) # 1. init environment rank_id, device_num = npu_config.set_npu_env(args) - # rank_id, device_num = init_env( - # args.mode, - # seed=args.seed, - # distributed=args.use_parallel, - # device_target=args.device, - # max_device_memory=args.max_device_memory, - # parallel_mode=args.parallel_mode, - # precision_mode=args.precision_mode, - # global_bf16=args.global_bf16, - # sp_size=args.sp_size, - # jit_level=args.jit_level, - # jit_syntax_level=args.jit_syntax_level, - # ) - + # 2. build models and pipeline - if args.num_frames != 1 and args.enhance_video is not None: + if args.num_frames != 1 and args.enhance_video is not None: #TODO from opensora.sample.VEnhancer.enhance_a_video import VEnhancer enhance_video_model = VEnhancer(model_path=args.enhance_video, version='v2', device=args.device) else: enhance_video_model = None - pipeline = prepare_pipeline(args) + + pipeline = prepare_pipeline(args) # build I2V/T2V pipeline - # TODO - # if args.caption_refiner is not None: - # caption_refiner_model = OpenSoraCaptionRefiner(args, dtype, device) - # else: - caption_refiner_model = None + if args.caption_refiner is not None: #TODO: TO TEST + caption_refiner_model = OpenSoraCaptionRefiner(args.caption_refiner, dtype=ms.float16) + else: + caption_refiner_model = None # 3. inference start_time = time.time() - run_model_and_save_samples(args, pipeline, rank_id, device_num, caption_refiner_model, enhance_video_model) + run_model_and_save_samples(args, pipeline, rank_id, device_num, save_dir, caption_refiner_model, enhance_video_model) end_time = time.time() time_cost = end_time - start_time logger.info(f"Inference time cost: {time_cost:0.3f}s") diff --git a/examples/opensora_pku/opensora/utils/sample_utils.py b/examples/opensora_pku/opensora/utils/sample_utils.py new file mode 100644 index 0000000000..64a383aaf5 --- /dev/null +++ b/examples/opensora_pku/opensora/utils/sample_utils.py @@ -0,0 +1,726 @@ +import argparse +import glob +import logging +import os +import sys +import time + +import numpy as np +import pandas as pd +import yaml +from PIL import Image +from tqdm import tqdm + +import mindspore as ms +from mindspore import nn + + +from mindone.diffusers import ( + DDIMScheduler, DDPMScheduler, PNDMScheduler, + EulerDiscreteScheduler, DPMSolverMultistepScheduler, + HeunDiscreteScheduler, EulerAncestralDiscreteScheduler, + DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler, + DPMSolverSinglestepScheduler, #CogVideoXDDIMScheduler, + FlowMatchEulerDiscreteScheduler + ) + +from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info +from opensora.dataset.text_dataset import create_dataloader +from opensora.utils.message_utils import print_banner +from opensora.utils.ms_utils import init_env +from opensora.utils.utils import _check_cfgs_in_parser, get_precision +from opensora.models.causalvideovae import ae_stride_config, ae_wrapper +# from opensora.sample.caption_refiner import OpenSoraCaptionRefiner +from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate +from examples.opensora_pku.opensora.models.diffusion.opensora.modeling_opensora import LayerNorm, OpenSoraT2V_v1_3 +from examples.opensora_pku.opensora.models.diffusion.opensora.modules import Attention +from opensora.sample.pipeline_opensora import OpenSoraPipeline + +from mindone.diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings +from mindone.diffusers import ( + DDIMScheduler, DDPMScheduler, PNDMScheduler, + EulerDiscreteScheduler, DPMSolverMultistepScheduler, + HeunDiscreteScheduler, EulerAncestralDiscreteScheduler, + DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler, + DPMSolverSinglestepScheduler, #CogVideoXDDIMScheduler, + FlowMatchEulerDiscreteScheduler + ) +from mindone.transformers import T5EncoderModel, MT5EncoderModel, CLIPTextModelWithProjection +from transformers import AutoTokenizer, MT5Tokenizer + +from mindone.utils.amp import auto_mixed_precision +from mindone.utils.config import str2bool +from mindone.utils.params import count_params +from mindone.visualize.videos import save_videos +from mindone.diffusers.training_utils import set_seed + +logger = logging.getLogger(__name__) + +def get_scheduler(args): + kwargs = dict( + prediction_type=args.prediction_type, + rescale_betas_zero_snr=args.rescale_betas_zero_snr, + timestep_spacing="trailing" if args.rescale_betas_zero_snr else 'leading', + ) + if args.v1_5_scheduler: + kwargs['beta_start'] = 0.00085 + kwargs['beta_end'] = 0.0120 + kwargs['beta_schedule'] = "scaled_linear" + if args.sample_method == 'DDIM': + scheduler_cls = DDIMScheduler + kwargs['clip_sample'] = False + elif args.sample_method == 'EulerDiscrete': + scheduler_cls = EulerDiscreteScheduler + elif args.sample_method == 'DDPM': + scheduler_cls = DDPMScheduler + kwargs['clip_sample'] = False + elif args.sample_method == 'DPMSolverMultistep': + scheduler_cls = DPMSolverMultistepScheduler + elif args.sample_method == 'DPMSolverSinglestep': + scheduler_cls = DPMSolverSinglestepScheduler + elif args.sample_method == 'PNDM': + scheduler_cls = PNDMScheduler + kwargs.pop('rescale_betas_zero_snr', None) + elif args.sample_method == 'HeunDiscrete': ######## + scheduler_cls = HeunDiscreteScheduler + elif args.sample_method == 'EulerAncestralDiscrete': + scheduler_cls = EulerAncestralDiscreteScheduler + elif args.sample_method == 'DEISMultistep': + scheduler_cls = DEISMultistepScheduler + kwargs.pop('rescale_betas_zero_snr', None) + elif args.sample_method == 'KDPM2AncestralDiscrete': ######### + scheduler_cls = KDPM2AncestralDiscreteScheduler + # elif args.sample_method == 'CogVideoX': + # scheduler_cls = CogVideoXDDIMScheduler + elif args.sample_method == 'FlowMatchEulerDiscrete': + scheduler_cls = FlowMatchEulerDiscreteScheduler + kwargs = {} + else: + raise NameError(f'Unsupport sample_method {args.sample_method}') + scheduler = scheduler_cls(**kwargs) + return scheduler + + + +def prepare_pipeline(args): + # VAE model initiate and weight loading + print_banner("vae init") + vae_dtype = get_precision(args.vae_precision) + if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint): + logger.info(f"Run inference with MindSpore checkpoint {args.ms_checkpoint}") + state_dict = ms.load_checkpoint(args.ms_checkpoint) + # rm 'network.' prefix + state_dict = dict( + [k.replace("network.", "") if k.startswith("network.") else k, v] for k, v in state_dict.items() + ) + else: + state_dict = None + kwarg = { + "state_dict": state_dict, + "use_safetensors": True, + "dtype": vae_dtype, + } + vae = ae_wrapper[args.ae](args.ae_path, **kwarg) + vae.vae_scale_factor = ae_stride_config[args.ae] + if args.enable_tiling: + vae.vae.enable_tiling() + vae.vae.tile_overlap_factor = args.tile_overlap_factor + + ## use amp level O2 for causal 3D VAE with bfloat16 or float16 + if vae_dtype == ms.float16: + custom_fp32_cells = [nn.GroupNorm] if args.vae_keep_gn_fp32 else [] + else: + custom_fp32_cells = [nn.AvgPool2d, TrilinearInterpolate] + logger.info(f"Use amp level O2 for causal 3D VAE with dtype={vae_dtype}, custom_fp32_cells: {custom_fp32_cells}") + vae = auto_mixed_precision(vae, amp_level="O2", dtype=vae_dtype, custom_fp32_cells=custom_fp32_cells) + + vae.set_train(False) + for param in vae.get_parameters(): # freeze vae + param.requires_grad = False + + if args.decode_latents: + print("To decode latents directly, skipped loading text endoers and transformer") + return vae + + # Build text encoders + print_banner("text encoder init") + text_encoder_dtype = get_precision(args.text_encoder_precision) + if 'mt5' in args.text_encoder_name_1: + text_encoder_1, loading_info = MT5EncoderModel.from_pretrained( + args.text_encoder_name_1, + cache_dir=args.cache_dir, + output_loading_info=True, + mindspore_dtype=text_encoder_dtype, + use_safetensors=True + ) + # loading_info.pop("unexpected_keys") # decoder weights are ignored + # logger.info(f"Loaded MT5 Encoder: {loading_info}") + text_encoder_1 = text_encoder_1.set_train(False) + else: + text_encoder_1 = T5EncoderModel.from_pretrained( + args.text_encoder_name_1, cache_dir=args.cache_dir, + mindspore_dtype=text_encoder_dtype + ).set_train(False) + tokenizer_1 = AutoTokenizer.from_pretrained( + args.text_encoder_name_1, cache_dir=args.cache_dir + ) + + if args.text_encoder_name_2 is not None: + text_encoder_2, loading_info = CLIPTextModelWithProjection.from_pretrained( + args.text_encoder_name_2, + cache_dir=args.cache_dir, + mindspore_dtype=text_encoder_dtype, + output_loading_info=True, + use_safetensors=True, + ) + # loading_info.pop("unexpected_keys") # only load text model, ignore vision model + # loading_info.pop("mising_keys") # Note: missed keys when loading open-clip models + # logger.info(f"Loaded CLIP Encoder: {loading_info}") + text_encoder_2 = text_encoder_2.set_train(False) + tokenizer_2 = AutoTokenizer.from_pretrained( + args.text_encoder_name_2, cache_dir=args.cache_dir + ) + else: + text_encoder_2, tokenizer_2 = None, None + + # Build transformer + print_banner("transformer model init") + FA_dtype = get_precision(args.precision) if get_precision(args.precision) != ms.float32 else ms.bfloat16 + assert args.model_type == "dit", "Currently only suppport model_type as 'dit'@" + if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint): + logger.info(f"Initiate from MindSpore checkpoint file {args.ms_checkpoint}") + state_dict = ms.load_checkpoint(args.ms_checkpoint) + # rm 'network.' prefix + state_dict = dict( + [k.replace("network.", "") if k.startswith("network.") else k, v] for k, v in state_dict.items() + ) + else: + state_dict = None + model_version = args.model_path.split("/")[-1] + if (args.version != 'v1_3') and (model_version.split("x")[0][:3] != "any"): + if int(model_version.split("x")[0]) != args.num_frames: + logger.warning( + f"Detect that the loaded model version is {model_version}, but found a mismatched number of frames {model_version.split('x')[0]}" + ) + if int(model_version.split("x")[1][:-1]) != args.height: + logger.warning( + f"Detect that the loaded model version is {model_version}, but found a mismatched resolution {args.height}x{args.width}" + ) + elif (args.version == 'v1_3') and (model_version.split("x")[0] == "any93x640x640"): # TODO: currently only release one model + if (args.height % 32 != 0) or (args.width % 32 != 0): + logger.warning( + f"Detect that the loaded model version is {model_version}, but found a mismatched resolution {args.height}x{args.width}. The resolution of the inference should be a multiple of 32." + ) + if (args.num_frames - 1) % 4 != 0: + logger.warning( + f"Detect that the loaded model version is {model_version}, but found a mismatched number of frames {args.num_frames}. Frames needs to be 4n+1, e.g. 93, 77, 61, 45, 29, 1 (image)" + ) + # dit_dtype = get_precision(args.precision) + # if dit_dtype == "fp16": # Attention processor cannot convert to fp16 + dit_dtype = None + if args.version == 'v1_3': + # TODO + # if args.model_type == 'inpaint' or args.model_type == 'i2v': + # transformer_model = OpenSoraInpaint_v1_3.from_pretrained( + # args.model_path, cache_dir=args.cache_dir, + # device_map=None, mindspore_dtype=weight_dtype + # ).set_train(False) + # else: + + transformer_model, logging_info = OpenSoraT2V_v1_3.from_pretrained( + args.model_path, + state_dict=state_dict, + cache_dir=args.cache_dir, + # mindspore_dtype=dit_dtype, + FA_dtype = FA_dtype, + output_loading_info=True, + ) + logger.info(logging_info) + elif args.version == 'v1_5': + if args.model_type == 'inpaint' or args.model_type == 'i2v': + raise NotImplementedError('Inpainting model is not available in v1_5') + else: + from opensora.models.diffusion.opensora_v1_5.modeling_opensora import OpenSoraT2V_v1_5 + transformer_model = OpenSoraT2V_v1_5.from_pretrained( + args.model_path, cache_dir=args.cache_dir, + # device_map=None, + mindspore_dtype=weight_dtype + ) + + # Mixed precision + dtype = get_precision(args.precision) + if args.precision in ["fp16", "bf16"]: + if not args.global_bf16: + amp_level = args.amp_level + if dtype == ms.float16: + custom_fp32_cells=[LayerNorm, Attention, nn.SiLU, nn.GELU, PixArtAlphaCombinedTimestepSizeEmbeddings] + else: + custom_fp32_cells= [ + nn.MaxPool2d, + nn.MaxPool3d, # do not support bf16 + LayerNorm, + nn.SiLU, + nn.GELU, + PixArtAlphaCombinedTimestepSizeEmbeddings, + ] + transformer_model = auto_mixed_precision( + transformer_model, + amp_level=args.amp_level, + dtype=dtype, + custom_fp32_cells=custom_fp32_cells + ) + logger.info( + f"Set mixed precision to {args.amp_level} with dtype={args.precision}, custom fp32_cells {custom_fp32_cells}" + ) + else: + logger.info(f"Using global bf16. Force model dtype from {dtype} to ms.bfloat16") + dtype = ms.bfloat16 + elif args.precision == "fp32": + amp_level = "O0" + else: + raise ValueError(f"Unsupported precision {args.precision}") + transformer_model = transformer_model.set_train(False) + for param in transformer_model.get_parameters(): # freeze transformer_model + param.requires_grad = False + + # Build scheduler + scheduler = get_scheduler(args) + + # Build inference pipeline + # pipeline_class = OpenSoraInpaintPipeline if args.model_type == 'inpaint' or args.model_type == 'i2v' else OpenSoraPipeline + pipeline_class = OpenSoraPipeline + + pipeline = pipeline_class( + vae=vae, + text_encoder=text_encoder_1, + tokenizer=tokenizer_1, + scheduler=scheduler, + transformer=transformer_model, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + ) + + if args.save_memory: #TODO: Susan comment: I am not sure yet + print('enable_model_cpu_offload AND enable_sequential_cpu_offload AND enable_tiling') + pipeline.enable_model_cpu_offload() + pipeline.enable_sequential_cpu_offload() + if not args.enable_tiling: + vae.vae.enable_tiling() + vae.vae.t_chunk_enc = 8 + vae.vae.t_chunk_dec = vae.vae.t_chunk_enc // 2 + + # Print key info + num_params_vae, num_params_vae_trainable = count_params(vae) + num_params_latte, num_params_latte_trainable = count_params(transformer_model) + num_params = num_params_vae + num_params_latte + num_params_trainable = num_params_vae_trainable + num_params_latte_trainable + key_info = "Key Settings:\n" + "=" * 50 + "\n" + key_info += "\n".join( + [ + f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}", + f"Jit level: {args.jit_level}", + f"Num of samples: {len(args.text_prompt)}", + f"Num params: {num_params:,} (latte: {num_params_latte:,}, vae: {num_params_vae:,})", + f"Num trainable params: {num_params_trainable:,}", + f"Transformer dtype: {dtype}", + f"VAE dtype: {vae_dtype}", + f"Text encoder dtype: {text_encoder_dtype}", + f"Sampling steps {args.num_sampling_steps}", + f"Sampling method: {args.sample_method}", + f"CFG guidance scale: {args.guidance_scale}", + f"FA dtype: {FA_dtype}", + f"Inference shape (num_frames x height x width): {args.num_frames}x{args.height}x{args.width}", + ] + ) + key_info += "\n" + "=" * 50 + logger.info(key_info) + + return pipeline + +## See npu_config.py set_npu_env() +# def init_npu_env(args): +# local_rank = int(os.getenv('RANK', 0)) +# world_size = int(os.getenv('WORLD_SIZE', 1)) +# args.local_rank = local_rank +# args.world_size = world_size +# torch_npu.npu.set_device(local_rank) +# dist.init_process_group( +# backend='hccl', init_method='env://', +# world_size=world_size, rank=local_rank +# ) +# if args.sp: +# initialize_sequence_parallel_state(world_size) +# return args + + +def run_model_and_save_samples(args, pipeline, rank_id, device_num, save_dir, caption_refiner_model=None, enhance_video_model=None): + if args.seed is not None: + set_seed(args.seed, rank=rank_id) + + # Handle input text prompts + print_banner("text prompts loading") + ext = ( + f"{args.video_extension}" if not (args.save_latents or args.decode_latents) else "npy" + ) # save video as gif or save denoised latents as npy files. + ext = "jpg" if args.num_frames == 1 else ext + if not isinstance(args.text_prompt, list): + args.text_prompt = [args.text_prompt] + # if input is a text file, where each line is a caption, load it into a list + if len(args.text_prompt) == 1 and args.text_prompt[0].endswith("txt"): + captions = open(args.text_prompt[0], "r").readlines() + args.text_prompt = [i.strip() for i in captions] + if len(args.text_prompt) == 1 and args.text_prompt[0].endswith("csv"): + captions = pd.read_csv(args.text_prompt[0]) + args.text_prompt = [i.strip() for i in captions["cap"]] + n = len(args.text_prompt) + assert n > 0, "No captions provided" + logger.info(f"Number of prompts: {n}") + logger.info(f"Number of generated samples for each prompt {args.num_videos_per_prompt}") + + # Create dataloader for the captions + csv_file = {"path": [], "cap": []} + for i in range(n): + for i_video in range(args.num_videos_per_prompt): + csv_file["path"].append(f"{i_video}-{args.text_prompt[i].strip()[:100]}.{ext}") + csv_file["cap"].append(args.text_prompt[i]) + temp_dataset_csv = os.path.join(save_dir, "dataset.csv") + pd.DataFrame.from_dict(csv_file).to_csv(temp_dataset_csv, index=False, columns=csv_file.keys()) + + ds_config = dict( + data_file_path=temp_dataset_csv, + tokenizer=None, # tokenizer, + file_column="path", + caption_column="cap", + ) + dataset = create_dataloader( + ds_config, + args.batch_size, + ds_name="text", + num_parallel_workers=12, + max_rowsize=32, + shuffle=False, # be in order + device_num=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), + rank_id=rank_id if not get_sequence_parallel_state() else hccl_info.group_id, + drop_remainder=False, + ) + dataset_size = dataset.get_dataset_size() + logger.info(f"Num batches: {dataset_size}") + ds_iter = dataset.create_dict_iterator(1, output_numpy=True) + + # Decode latents directly + if args.decode_latents: + assert isinstance(pipeline, ae_wrapper[args.ae]) + vae = pipeline + for step, data in tqdm(enumerate(ds_iter), total=dataset_size): + file_paths = data["file_path"] + loaded_latents = [] + for i_sample in range(args.batch_size): + save_fp = os.path.join(save_dir, file_paths[i_sample]) + assert os.path.exists( + save_fp + ), f"{save_fp} does not exist! Please check the npy files under {save_dir} or check if you run `--save_latents` ahead." + loaded_latents.append(np.load(save_fp)) + loaded_latents = ( + np.stack(loaded_latents) if loaded_latents[0].ndim == 4 else np.concatenate(loaded_latents, axis=0) + ) + decode_data = ( + vae.decode(ms.Tensor(loaded_latents)).permute(0, 1, 3, 4, 2).to(ms.float32) + ) # (b t c h w) -> (b t h w c) + decode_data = ms.ops.clip_by_value( + (decode_data + 1.0) / 2.0, clip_value_min=0.0, clip_value_max=1.0 + ).asnumpy() + for i_sample in range(args.batch_size): + save_fp = os.path.join(save_dir, file_paths[i_sample]).replace(".npy", f".{args.video_extension}") + save_video_data = decode_data[i_sample : i_sample + 1] + save_videos(save_video_data, save_fp, loop=0, fps=args.fps) # (b t h w c) + + # Delete files that are no longer needed + if os.path.exists(temp_dataset_csv): + os.remove(temp_dataset_csv) + + if args.decode_latents: + npy_files = glob.glob(os.path.join(save_dir, "*.npy")) + for fp in npy_files: + os.remove(fp) + + # TODO + # if args.model_type == 'inpaint' or args.model_type == 'i2v': + # if not isinstance(args.conditional_pixel_values_path, list): + # args.conditional_pixel_values_path = [args.conditional_pixel_values_path] + # if len(args.conditional_pixel_values_path) == 1 and args.conditional_pixel_values_path[0].endswith('txt'): + # temp = open(args.conditional_pixel_values_path[0], 'r').readlines() + # conditional_pixel_values_path = [i.strip().split(',') for i in temp] + # mask_type = args.mask_type if args.mask_type is not None else None + + positive_prompt = """ + high quality, high aesthetic, {} + """ + negative_prompt = """ + nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, + low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry. + """ + # positive_prompt = ( + # "(masterpiece), (best quality), (ultra-detailed), {}. emotional, " + # + "harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo, sharp focus, high budget, cinemascope, moody, epic, gorgeous" + # ) + # negative_prompt = ( + # "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, " + # + "extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" + # ) + + def generate(step, data, ext, conditional_pixel_values_path=None, mask_type=None): + + + if args.caption_refiner is not None: + if args.model_type != 'inpaint' and args.model_type != 'i2v': + refine_prompt = caption_refiner_model.get_refiner_output(prompt) + print(f'\nOrigin prompt: {prompt}\n->\nRefine prompt: {refine_prompt}') + prompt = refine_prompt + else: + # Due to the current use of LLM as the caption refiner, additional content that is not present in the control image will be added. Therefore, caption refiner is not used in this mode. + print('Caption refiner is not available for inpainting model, use the original prompt...') + time.sleep(3) + # TODO + # input_prompt = positive_prompt.format(prompt) + # if args.model_type == 'inpaint' or args.model_type == 'i2v': + # print(f'\nConditional pixel values path: {conditional_pixel_values_path}') + # videos = pipeline( + # conditional_pixel_values_path=conditional_pixel_values_path, + # mask_type=mask_type, + # crop_for_hw=args.crop_for_hw, + # max_hxw=args.max_hxw, + # prompt=input_prompt, + # negative_prompt=negative_prompt, + # num_frames=args.num_frames, + # height=args.height, + # width=args.width, + # num_inference_steps=args.num_sampling_steps, + # guidance_scale=args.guidance_scale, + # num_samples_per_prompt=args.num_samples_per_prompt, + # max_sequence_length=args.max_sequence_length, + # ).videos + # else: + prompt = [x for x in data["caption"]] + file_paths = data["file_path"] + input_prompt = positive_prompt.format(prompt[0]) # remove "[]" + saved_prompt1_dict = None + + videos = ( + pipeline( + input_prompt, + negative_prompt=negative_prompt, + num_frames=args.num_frames, + height=args.height, + width=args.width, + num_inference_steps=args.num_sampling_steps, + guidance_scale=args.guidance_scale, + num_samples_per_prompt=args.num_samples_per_prompt, + output_type="latents" if args.save_latents else "pil", + max_sequence_length=args.max_sequence_length, + ) + .videos.to(ms.float32) + .asnumpy() + ) + # if enhance_video_model is not None: + # # b t h w c + # videos = enhance_video_model.enhance_a_video(videos, input_prompt, 2.0, args.fps, 250) + if step == 0 and profiler is not None: + profiler.stop() + + if get_sequence_parallel_state() and hccl_info.rank % hccl_info.world_size != 0: + pass + else: + # save result + for i_sample in range(args.batch_size): + file_path = os.path.join(save_dir, file_paths[i_sample]) + assert ext in file_path, f"Only support saving as {ext} files, but got {file_path}." + if args.save_latents: + np.save(file_path, videos[i_sample : i_sample + 1]) + else: + if args.num_frames == 1: + ext = "jpg" + image = videos[i_sample, 0] # (b t h w c) -> (h, w, c) + image = (image * 255).round().clip(0, 255).astype(np.uint8) + Image.fromarray(image).save(file_path) + else: + save_video_data = videos[i_sample : i_sample + 1] # (b t h w c) + save_videos(save_video_data, file_path, loop=0, fps=args.fps) + + if args.profile: + profiler = ms.Profiler(output_path="./mem_info", profile_memory=True) + ms.set_context(memory_optimize_level="O0") + ms.set_context(pynative_synchronize=True) + else: + profiler = None + + # Infer + # if args.model_type == 'inpaint' or args.model_type == 'i2v': + # for index, (prompt, cond_path) in enumerate(zip(args.text_prompt, conditional_pixel_values_path)): + # if not args.sp and args.local_rank != -1 and index % args.world_size != args.local_rank: + # continue + # generate(prompt, conditional_pixel_values_path=cond_path, mask_type=mask_type) + # print('completed, please check the saved images and videos') + # else: + for step, data in tqdm(enumerate(ds_iter), total=dataset_size): + generate(step, data, ext) + break # TODO: debug use, delete later + + + # Delete files that are no longer needed + if os.path.exists(temp_dataset_csv): + os.remove(temp_dataset_csv) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--version", type=str, default='v1_3', choices=['v1_3', 'v1_5']) + parser.add_argument("--caption_refiner", type=str, default=None, help="caption refiner model path") + parser.add_argument("--enhance_video", type=str, default=None) + parser.add_argument("--text_encoder_name_1", type=str, default='DeepFloyd/t5-v1_1-xxl', help="google/mt5-xxl, DeepFloyd/t5-v1_1-xxl") + parser.add_argument("--text_encoder_name_2", type=str, default=None, help=" openai/clip-vit-large-patch14, (laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)") + parser.add_argument("--num_samples_per_prompt", type=int, default=1) + parser.add_argument('--refine_caption', action='store_true') + # parser.add_argument('--compile', action='store_true') + parser.add_argument("--prediction_type", type=str, default='epsilon', help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.") + parser.add_argument('--rescale_betas_zero_snr', action='store_true') + # parser.add_argument('--local_rank', type=int, default=-1) + # parser.add_argument('--world_size', type=int, default=1) + # parser.add_argument('--sp', action='store_true') + parser.add_argument('--v1_5_scheduler', action='store_true') + parser.add_argument('--conditional_pixel_values_path', type=str, default=None) + parser.add_argument('--mask_type', type=str, default=None) + parser.add_argument('--crop_for_hw', action='store_true') + parser.add_argument('--max_hxw', type=int, default=236544) #236544=512x462???? + + parser.add_argument( + "--config", + "-c", + default="", + type=str, + help="path to load a config yaml file that describes the setting which will override the default arguments", + ) + parser.add_argument("--model_path", type=str, default="LanguageBind/Open-Sora-Plan-v1.3.0") + parser.add_argument( + "--ms_checkpoint", + type=str, + default=None, + help="If not provided, will search for ckpt file under `model_path`" + "If provided, will use this pretrained ckpt path.", + ) + parser.add_argument("--num_frames", type=int, default=1) + parser.add_argument("--height", type=int, default=512) + parser.add_argument("--width", type=int, default=512) + parser.add_argument("--ae", type=str, default="CausalVAEModel_4x8x8") + parser.add_argument("--ae_path", type=str, default="CausalVAEModel_4x8x8") + parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel") + + parser.add_argument("--text_encoder_name", type=str, default="DeepFloyd/t5-v1_1-xxl") + parser.add_argument("--save_img_path", type=str, default="./sample_videos/t2v") + + parser.add_argument("--guidance_scale", type=float, default=7.5, help="the scale for classifier-free guidance") + parser.add_argument("--max_sequence_length", type=int, default=512, help="the maximum text tokens length") + + parser.add_argument("--sample_method", type=str, default="PNDM") + parser.add_argument("--num_sampling_steps", type=int, default=50, help="Diffusion Sampling Steps") + parser.add_argument("--fps", type=int, default=24) + parser.add_argument( + "--text_prompt", + type=str, + nargs="+", + help="A list of text prompts to be generated with. Also allow input a txt file or csv file.", + ) + parser.add_argument("--tile_overlap_factor", type=float, default=0.25) + + parser.add_argument("--enable_tiling", action="store_true", help="whether to use vae tiling to save memory") + parser.add_argument("--model_3d", action="store_true") + parser.add_argument("--udit", action="store_true") + parser.add_argument("--save_memory", action="store_true") + parser.add_argument("--batch_size", default=1, type=int, help="batch size for dataloader") + # MS new args + parser.add_argument("--device", type=str, default="Ascend", help="Ascend or GPU") + parser.add_argument("--max_device_memory", type=str, default=None, help="e.g. `30GB` for 910a, `59GB` for 910b") + parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") + parser.add_argument("--use_parallel", default=False, type=str2bool, help="use parallel") + parser.add_argument( + "--parallel_mode", default="data", type=str, choices=["data", "optim"], help="parallel mode: data, optim" + ) + parser.add_argument("--jit_level", default="O0", help="Set jit level: # O0: KBK, O1:DVM, O2: GE") + parser.add_argument( + "--jit_syntax_level", default="strict", choices=["strict", "lax"], help="Set jit syntax level: strict or lax" + ) + parser.add_argument("--seed", type=int, default=42, help="Inference seed") + + parser.add_argument( + "--precision", + default="bf16", + type=str, + choices=["bf16", "fp16", "fp32"], + help="what data type to use for latte. Default is `fp16`, which corresponds to ms.float16", + ) + parser.add_argument( + "--global_bf16", action="store_true", help="whether to enable gloabal bf16 for diffusion model training." + ) + parser.add_argument( + "--vae_precision", + default="fp16", + type=str, + choices=["bf16", "fp16"], + help="what data type to use for vae. Default is `bf16`, which corresponds to ms.bfloat16", + ) + parser.add_argument( + "--vae_keep_gn_fp32", + default=False, + type=str2bool, + help="whether keep GroupNorm in fp32. Defaults to False in inference mode. If training vae, better set it to True", + ) + parser.add_argument( + "--text_encoder_precision", + default="fp16", + type=str, + choices=["bf16", "fp16"], + help="what data type to use for T5 text encoder. Default is `bf16`, which corresponds to ms.bfloat16", + ) + parser.add_argument( + "--amp_level", type=str, default="O2", help="Set the amp level for the transformer model. Defaults to O2." + ) + parser.add_argument( + "--precision_mode", + default=None, + type=str, + help="If specified, set the precision mode for Ascend configurations.", + ) + parser.add_argument( + "--num_videos_per_prompt", type=int, default=1, help="the number of images to be generated for each prompt" + ) + parser.add_argument( + "--save_latents", + action="store_true", + help="Whether to save latents (before vae decoding) instead of video files.", + ) + parser.add_argument( + "--decode_latents", + action="store_true", + help="whether to load the existing latents saved in npy files and run vae decoding", + ) + parser.add_argument( + "--video_extension", default="mp4", choices=["gif", "mp4"], help="The file extension to save videos" + ) + parser.add_argument("--model_type", type=str, default="dit", choices=["dit", "udit", "latte", 't2v', 'inpaint', 'i2v']) + parser.add_argument("--cache_dir", type=str, default="./") + parser.add_argument("--profile", default=False, type=str2bool, help="Profile or not") + + default_args = parser.parse_args() + abs_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "")) + if default_args.config: + logger.info(f"Overwrite default arguments with configuration file {default_args.config}") + default_args.config = os.path.join(abs_path, default_args.config) + with open(default_args.config, "r") as f: + cfg = yaml.safe_load(f) + _check_cfgs_in_parser(cfg, parser) + parser.set_defaults(**cfg) + args = parser.parse_args() + + assert not (args.use_parallel and args.num_frames == 1) + + return args diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_debug.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_debug.sh new file mode 100644 index 0000000000..b4cc745ef0 --- /dev/null +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_debug.sh @@ -0,0 +1,36 @@ +# Quick debug for DiT config: +# - 1 NPU/GPU +# - fewer frames: 29 +# - small and uncommon resolution: 352x640 +# - fps: 24 +# - precision: bf16 (Some exceptions ref to sample_utils.py. Torch ver doesn't share, always uses fp16. ) + +# Debug first prompt only: +# "A young man at his 20s is sitting on a piece of cloud in the sky, reading a book." + +# To use: +# change model_path, text_encoder_name_1, ae_path, save_img_path before running the script. + +export DEVICE_ID=0 +python opensora/sample/sample.py \ + --model_path /home_host/susan/workspace/checkpoints/LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ + --version v1_3 \ + --num_frames 29 \ + --height 352 \ + --width 640 \ + --text_encoder_name_1 /home_host/susan/workspace/checkpoints/google/mt5-xxl \ + --text_prompt examples/prompt_list_0.txt \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path /home_host/susan/workspace/checkpoints/LanguageBind/Open-Sora-Plan-v1.3.0/vae \ + --save_img_path "./sample_videos/prompt_list_0_29x640_mt5_bf16_debug" \ + --fps 24 \ + --guidance_scale 7.5 \ + --num_sampling_steps 100 \ + --enable_tiling \ + --max_sequence_length 512 \ + --sample_method EulerAncestralDiscrete \ + --seed 1234 \ + --num_samples_per_prompt 1 \ + --rescale_betas_zero_snr \ + --prediction_type "v_prediction" \ + --mode 1 --precision bf16 \ No newline at end of file diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x1280.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x1280.sh index 665483b0ca..cf3373c426 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x1280.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x1280.sh @@ -9,11 +9,10 @@ python opensora/sample/sample.py \ --height 704 \ --width 1280 \ --text_encoder_name_1 google/mt5-xxl \ - --text_encoder_name_2 laion/CLIP-ViT-bigG-14-laion2B-39B-b160k \ --text_prompt examples/prompt_list_0.txt \ --ae WFVAEModel_D8_4x8x8 \ --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ - --save_img_path "./sample_videos/prompt_list_0_29x1280_mt5_openclip" \ + --save_img_path "./sample_videos/prompt_list_0_29x1280" \ --fps 24 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x480p.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x480p.sh index a78e05ce98..9b061c59e9 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x480p.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x480p.sh @@ -1,14 +1,14 @@ export DEVICE_ID=0 -python opensora/sample/sample_t2v.py \ - --model_path LanguageBind/Open-Sora-Plan-v1.2.0/29x480p \ +python opensora/sample/sample_v1_3.py \ + --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ + --version v1_3 \ --num_frames 29 \ --height 480 \ --width 640 \ - --cache_dir "./" \ - --text_encoder_name google/mt5-xxl \ + --text_encoder_name_1 google/mt5-xxl \ --text_prompt examples/prompt_list_0.txt \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path LanguageBind/Open-Sora-Plan-v1.2.0/vae\ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ --save_img_path "./sample_videos/prompt_list_0_29x480p" \ --fps 24 \ --guidance_scale 7.5 \ @@ -16,4 +16,7 @@ python opensora/sample/sample_t2v.py \ --enable_tiling \ --max_sequence_length 512 \ --sample_method EulerAncestralDiscrete \ - --model_type "dit" \ + --num_samples_per_prompt 1 \ + --rescale_betas_zero_snr \ + --prediction_type "v_prediction" \ + --mode 1 diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_1texenc.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh similarity index 95% rename from examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_1texenc.sh rename to examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh index 4f28620f74..0d2b17e057 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_1texenc.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh @@ -2,7 +2,7 @@ # So keep the resolution of the inference a multiple of 32. Frames needs to be 4n+1, e.g. 93, 77, 61, 45, 29, 1 (image). export DEVICE_ID=0 -python opensora/sample/sample_v1_3.py \ +python opensora/sample/sample.py \ --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ --version v1_3 \ --num_frames 93 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_2texenc.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_2texenc.sh index 513a9c3925..5794be27e1 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_2texenc.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_2texenc.sh @@ -1,5 +1,8 @@ +# This script is for futher function +# Does not work yet for 2nd text encoder + export DEVICE_ID=0 -python opensora/sample/sample_v1_3.py \ +python opensora/sample/sample.py \ --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ --version v1_3 \ --num_frames 93 \ diff --git a/examples/opensora_pku/torch_intermediate_states/readme_load_states.md b/examples/opensora_pku/torch_intermediate_states/readme_load_states.md new file mode 100644 index 0000000000..89d923b752 --- /dev/null +++ b/examples/opensora_pku/torch_intermediate_states/readme_load_states.md @@ -0,0 +1,210 @@ +### Updated files +- opensora/utils/sample.utils.py (NEW) +- opensora/sample/sample.py <- opensora/sample/sample_t2v.py +- opensora/sample/pipeline_opensora.py + +- opensora/models/diffusion/common.py <- opensora/models/diffusion/opensora/rope.py +- opensora/models/diffusion/opensora/modeling_opensora.py +- opensora/models/diffusion/opensora/modules.py + +### Debugging script +scripts/text_condition/single-device/sample_debug.sh + +### Intermediate dicts to load +Details of saving intermediate states in Pytorch Version opensora/models/diffusion/modeling_opensora.py forward(): + +Note: I only save them in first step of denoising. + +```python +def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + **kwargs, + ): + ##################################### + ## !!!SAVE `input parameters + np.save("./hidden_states_input.npy", hidden_states.float().cpu().numpy()) + np.save("./timestep_input.npy", timestep.float().cpu().numpy()) + np.save("./encoder_hidden_states_input.npy", encoder_hidden_states.float().cpu().numpy()) + np.save("./attention_mask_input.npy", attention_mask.float().cpu().numpy()) + np.save("./encoder_attention_mask_input.npy", encoder_attention_mask.float().cpu().numpy()) + ##################################### + + batch_size, c, frame, h, w = hidden_states.shape + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 4: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + # b, frame, h, w -> a video + # b, 1, h, w -> only images + attention_mask = attention_mask.to(self.dtype) + + attention_mask = attention_mask.unsqueeze(1) # b 1 t h w + attention_mask = F.max_pool3d( + attention_mask, + kernel_size=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size), + stride=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size) + ) + attention_mask = rearrange(attention_mask, 'b 1 t h w -> (b 1) 1 (t h w)') + attention_mask = (1 - attention_mask.bool().to(self.dtype)) * -10000.0 + + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: + # b, 1, l + encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 + + ##################################### + ## !!! SAVE `masks` after conversion + # Note: they used "0" as True, "-10000" as False, and masks have transposed dimension + # I do not suggest to load these masks in Mindspore version + np.save("./attention_mask_converted.npy", attention_mask.float().cpu().numpy()) + np.save("./encoder_attention_mask_converted.npy", encoder_attention_mask.float().cpu().numpy()) + ##################################### + + # 1. Input + frame = ((frame - 1) // self.config.patch_size_t + 1) if frame % 2 == 1 else frame // self.config.patch_size_t # patchfy + height, width = hidden_states.shape[-2] // self.config.patch_size, hidden_states.shape[-1] // self.config.patch_size + + hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( + hidden_states, encoder_hidden_states, timestep, batch_size, frame + ) + + ##################################### + ## !!! SAVE states after `_operate_on_patched_inputs` + np.save("./hidden_states_operate_on_patched_inputs.npy", hidden_states.float().cpu().numpy()) + np.save("./encoder_hidden_states_operate_on_patched_inputs.npy", encoder_hidden_states.float().cpu().numpy()) + np.save("./timestep_operate_on_patched_inputs.npy", timestep.float().cpu().numpy()) + np.save("./embedded_timestep_operate_on_patched_inputs.npy", embedded_timestep.float().cpu().numpy()) + ##################################### + + # To + # x (t*h*w b d) or (t//sp*h*w b d) + # cond_1 (l b d) or (l//sp b d) + hidden_states = rearrange(hidden_states, 'b s h -> s b h', b=batch_size).contiguous() + encoder_hidden_states = rearrange(encoder_hidden_states, 'b s h -> s b h', b=batch_size).contiguous() + timestep = timestep.view(batch_size, 6, -1).transpose(0, 1).contiguous() + + sparse_mask = {} + if npu_config is None: + if get_sequence_parallel_state(): + head_num = self.config.num_attention_heads // nccl_info.world_size + else: + head_num = self.config.num_attention_heads + else: + head_num = None + for sparse_n in [1, 4]: + sparse_mask[sparse_n] = Attention.prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n, head_num) + ##################################### + ## !!! SAVE sparse masks + # Note: they used "0" as True, "-10000" as False, and masks have transposed dimension + # I do not suggest to load these masks in Mindspore version + attention_mask_sparse_1_False, encoder_attention_mask_sparse_1_False = sparse_mask[1][False] # mask_sparse_1d + attention_mask_sparse_1_True, encoder_attention_mask_sparse_1_True = sparse_mask[1][True] # mask_sparse_1d_group + attention_mask_sparse_4_False, encoder_attention_mask_sparse_4_False = sparse_mask[4][False] # mask_sparse_1d + attention_mask_sparse_4_True, encoder_attention_mask_sparse_4_True = sparse_mask[4][True] # sparse_1d_group + np.save("./attention_mask_sparse_1_False.npy", attention_mask_sparse_1_False.float().cpu().numpy()) + np.save("./encoder_attention_mask_sparse_1_False.npy", encoder_attention_mask_sparse_1_False.float().cpu().numpy()) + np.save("./attention_mask_sparse_1_True.npy", attention_mask_sparse_1_True.float().cpu().numpy()) + np.save("./encoder_attention_mask_sparse_1_True.npy", encoder_attention_mask_sparse_1_True.float().cpu().numpy()) + np.save("./attention_mask_sparse_4_False.npy", attention_mask_sparse_4_False.float().cpu().numpy()) + np.save("./encoder_attention_mask_sparse_4_False.npy", encoder_attention_mask_sparse_4_False.float().cpu().numpy()) + np.save("./attention_mask_sparse_4_True.npy", attention_mask_sparse_4_True.float().cpu().numpy()) + np.save("./encoder_attention_mask_sparse_4_True.npy", encoder_attention_mask_sparse_4_True.float().cpu().numpy()) + ##################################### + + + # 2. Blocks + ##################################### + # !!! SAVE initial input states + np.save(f"./hidden_states_before_block.npy", hidden_states.float().cpu().numpy()) + ##################################### + for i, block in enumerate(self.transformer_blocks): + if i > 1 and i < 30: + attention_mask, encoder_attention_mask = sparse_mask[block.attn1.processor.sparse_n][block.attn1.processor.sparse_group] + else: + attention_mask, encoder_attention_mask = sparse_mask[1][block.attn1.processor.sparse_group] + + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + frame, + height, + width, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + frame=frame, + height=height, + width=width, + ) + ##################################### + # !!! SAVE updated states + np.save(f"./hidden_states_{i}_block.npy", hidden_states.float().cpu().numpy()) + ##################################### + + + # To (b, t*h*w, h) or (b, t//sp*h*w, h) + hidden_states = rearrange(hidden_states, 's b h -> b s h', b=batch_size).contiguous() + + # 3. Output + output = self._get_output_for_patched_inputs( + hidden_states=hidden_states, + timestep=timestep, + embedded_timestep=embedded_timestep, + num_frames=frame, + height=height, + width=width, + ) # b c t h w + + ##################################### + #!!! SAVE output hidden states + np.save("./hidden_states_output.npy", output.float().cpu().numpy()) + ##################################### + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + +``` \ No newline at end of file From be4e1d26513c9c4c05a4ae95503e72baeadfb95c Mon Sep 17 00:00:00 2001 From: chenyingshu Date: Mon, 4 Nov 2024 18:15:32 +0800 Subject: [PATCH 004/133] improve acc of PatchEmbed2D --- .../opensora/models/diffusion/opensora/modeling_opensora.py | 2 +- examples/opensora_pku/opensora/utils/sample_utils.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index e55a7b8c23..d07c46a107 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -268,7 +268,7 @@ def construct( # return_dict: bool = True, **kwargs, ): - dtype = ms.float16 + dtype = self.dtype # debug use batch_size, c, frame, h, w = hidden_states.shape # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. diff --git a/examples/opensora_pku/opensora/utils/sample_utils.py b/examples/opensora_pku/opensora/utils/sample_utils.py index 64a383aaf5..d08a0dfcba 100644 --- a/examples/opensora_pku/opensora/utils/sample_utils.py +++ b/examples/opensora_pku/opensora/utils/sample_utils.py @@ -35,6 +35,7 @@ from examples.opensora_pku.opensora.models.diffusion.opensora.modeling_opensora import LayerNorm, OpenSoraT2V_v1_3 from examples.opensora_pku.opensora.models.diffusion.opensora.modules import Attention from opensora.sample.pipeline_opensora import OpenSoraPipeline +from opensora.models.diffusion.common import PatchEmbed2D from mindone.diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings from mindone.diffusers import ( @@ -253,11 +254,12 @@ def prepare_pipeline(args): if not args.global_bf16: amp_level = args.amp_level if dtype == ms.float16: - custom_fp32_cells=[LayerNorm, Attention, nn.SiLU, nn.GELU, PixArtAlphaCombinedTimestepSizeEmbeddings] + custom_fp32_cells=[LayerNorm, Attention, PatchEmbed2D, nn.SiLU, nn.GELU, PixArtAlphaCombinedTimestepSizeEmbeddings] else: custom_fp32_cells= [ nn.MaxPool2d, nn.MaxPool3d, # do not support bf16 + PatchEmbed2D, # low accuracy if using bf16 LayerNorm, nn.SiLU, nn.GELU, From 62f175d9644d30b0878fe6abf4c9c57e825d8087 Mon Sep 17 00:00:00 2001 From: chenyingshu Date: Wed, 6 Nov 2024 14:39:36 +0800 Subject: [PATCH 005/133] align structure, fix bugs --- .../opensora/models/diffusion/common.py | 4 +- .../diffusion/opensora/modeling_opensora.py | 33 ++- .../models/diffusion/opensora/modules.py | 203 +++--------------- .../opensora/utils/sample_utils.py | 5 - 4 files changed, 51 insertions(+), 194 deletions(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/common.py b/examples/opensora_pku/opensora/models/diffusion/common.py index e37fb528fb..a5f4a47125 100644 --- a/examples/opensora_pku/opensora/models/diffusion/common.py +++ b/examples/opensora_pku/opensora/models/diffusion/common.py @@ -21,13 +21,13 @@ def __init__( super().__init__() self.proj = nn.Conv2d( in_channels, embed_dim, - kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), has_bias=bias + kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), has_bias=bias, pad_mode="pad" ) def construct(self, latent): b, c, t, h, w = latent.shape # b, c=in_channels, t, h, w # b c t h w -> (b t) c h w - latent = latent.permute(0, 2, 1, 3, 4).reshape(b*t, c, h, w) # b*t, c, h, w + latent = latent.swapaxes(1, 2).reshape(b*t, c, h, w) # b*t, c, h, w latent = self.proj(latent) # b*t, embed_dim, h, w # (b t) c h w -> b (t h w) c _, c, h, w = latent.shape diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index d07c46a107..e8a5546c46 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -19,6 +19,7 @@ from opensora.models.diffusion.opensora.modules import BasicTransformerBlock, LayerNorm, Attention from opensora.models.diffusion.common import PatchEmbed2D +from opensora.npu_config import npu_config class OpenSoraT2V_v1_3(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @@ -52,7 +53,6 @@ def __init__( sparse1d: bool = False, sparse_n: int = 2, - attention_mode: str = "xformers", #NEW use_recompute=False, #NEW FA_dtype=ms.bfloat16, #NEW num_no_recompute: int = 0, #NEW @@ -60,11 +60,10 @@ def __init__( super().__init__() # Set some common variables used across the board. self.out_channels = in_channels if out_channels is None else out_channels - self.config.hidden_size = self.config.num_attention_heads * self.config.attention_head_dim + self.config.hidden_size = self.config.num_attention_heads * self.config.attention_head_dim #24*96=2304 self.gradient_checkpointing = use_recompute #NEW self.use_recompute = use_recompute #NEW self.FA_dtype = FA_dtype #NEW - self.attention_mode = attention_mode #NEW self._init_patched_inputs() def _init_patched_inputs(self): @@ -81,9 +80,9 @@ def _init_patched_inputs(self): ) self.pos_embed = PatchEmbed2D( - patch_size=self.config.patch_size, - in_channels=self.config.in_channels, - embed_dim=self.config.hidden_size, + patch_size=self.config.patch_size, #2 + in_channels=self.config.in_channels, #8 + embed_dim=self.config.hidden_size, #2304 ) self.transformer_blocks = nn.CellList( [ @@ -104,6 +103,7 @@ def _init_patched_inputs(self): sparse1d=self.config.sparse1d if i > 1 and i < 30 else False, sparse_n=self.config.sparse_n, sparse_group=i % 2 == 1, + FA_dtype=self.FA_dtype ) for i in range(self.config.num_layers) ] @@ -254,8 +254,8 @@ def _set_gradient_checkpointing(self, module, value=False): def get_attention_mask(self, attention_mask): if attention_mask is not None: - if self.attention_mode != "math": - attention_mask = attention_mask.to(ms.bool_) + if not npu_config.enable_FA: + attention_mask = attention_mask.to(ms.bool_) # use bool for sdpa return attention_mask def construct( @@ -265,10 +265,9 @@ def construct( encoder_hidden_states: Optional[ms.Tensor] = None, attention_mask: Optional[ms.Tensor] = None, encoder_attention_mask: Optional[ms.Tensor] = None, - # return_dict: bool = True, + return_dict: bool = True, **kwargs, ): - dtype = self.dtype # debug use batch_size, c, frame, h, w = hidden_states.shape # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. @@ -283,8 +282,6 @@ def construct( if attention_mask is not None and attention_mask.ndim == 4: # assume that mask is expressed as: # (1 = keep, 0 = discard) - # convert mask into a bias that can be added to attention scores: - # (keep = +0, discard = -10000.0) # b, frame, h, w -> a video # b, 1, h, w -> only images attention_mask = attention_mask.to(self.dtype) @@ -294,14 +291,12 @@ def construct( # b 1 t h w -> (b 1) 1 (t h w) attention_mask = attention_mask.reshape(batch_size, 1, -1) - # attention_mask = (1 - attention_mask.bool().to(self.dtype)) * -10000.0 #TODO: TBD - attention_mask = self.get_attention_mask(attention_mask) # use bool mask for FA + attention_mask = self.get_attention_mask(attention_mask) # if use bool mask # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # b, 1, l - # encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 - encoder_attention_mask = self.get_attention_mask(encoder_attention_mask) # use bool mask for FA + encoder_attention_mask = self.get_attention_mask(encoder_attention_mask) # if use bool mask # 1. Input @@ -309,7 +304,7 @@ def construct( height, width = hidden_states.shape[-2] // self.config.patch_size, hidden_states.shape[-1] // self.config.patch_size hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( - hidden_states, encoder_hidden_states, timestep, batch_size, frame, dtype=dtype + hidden_states, encoder_hidden_states, timestep, batch_size, frame, dtype=self.dtype ) if get_sequence_parallel_state(): @@ -340,7 +335,7 @@ def construct( else: attention_mask, encoder_attention_mask = sparse_mask[1][block.attn1.processor.sparse_group] - # if self.training and self.gradient_checkpointing: #TODO: training + # if self.training and self.gradient_checkpointing: #TODO: training hidden_states = block( hidden_states, @@ -372,7 +367,7 @@ def construct( return output - def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, batch_size, frame, dtype=ms.float16): + def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, batch_size, frame, dtype): hidden_states = self.pos_embed(hidden_states.to(dtype)) # (b, t*h*w, d) added_cond_kwargs = {"resolution": None, "aspect_ratio": None} diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py index 84c9a83abb..bc1f0db9cc 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py @@ -5,6 +5,7 @@ import numpy as np from opensora.acceleration.communications import AllToAll_SBH from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info +from opensora.npu_config import npu_config import mindspore as ms from mindspore import Parameter, mint, nn, ops @@ -38,30 +39,20 @@ def construct(self, x: ms.Tensor): x, _, _ = self.layer_norm(x, self.gamma, self.beta) return x -def get_attention_mask(attention_mask, repeat_num, attention_mode="xformers"): - if attention_mask is not None: - if attention_mode != "math": - attention_mask = attention_mask.to(ms.bool_) - else: - attention_mask = attention_mask.repeat_interleave(repeat_num, dim=-2) - return attention_mask - class Attention(Attention_): def __init__( self, interpolation_scale_thw, sparse1d, sparse_n, - sparse_group, is_cross_attn, attention_mode="xformers", **kwags + sparse_group, is_cross_attn, **kwags ): - FA_dtype = kwags.pop("FA_dtype", ms.bfloat16) + processor = OpenSoraAttnProcessor2_0( interpolation_scale_thw=interpolation_scale_thw, sparse1d=sparse1d, sparse_n=sparse_n, sparse_group=sparse_group, is_cross_attn=is_cross_attn, - attention_mode=attention_mode, - FA_dtype=FA_dtype, dim_head=kwags["dim_head"] + dim_head=kwags["dim_head"] ) - kwags["processor"] = processor - super().__init__(**kwags) - if attention_mode == "xformers": + super().__init__(processor=processor, **kwags) + if npu_config.enable_FA: self.set_use_memory_efficient_attention_xformers(True) self.processor = processor @@ -76,7 +67,7 @@ def prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n, head_n else: pad_len = sparse_n * sparse_n - l % (sparse_n * sparse_n) - attention_mask_sparse = mint.nn.functional.pad(attention_mask, (0, pad_len, 0, 0), mode="constant", value=-9980.0) + attention_mask_sparse = mint.nn.functional.pad(attention_mask, (0, pad_len, 0, 0), mode="constant", value=0) # 0 for discard b = attention_mask_sparse.shape[0] k = sparse_n m = sparse_n @@ -85,30 +76,25 @@ def prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n, head_n # b 1 1 (n m k) -> (m b) 1 1 (n k) attention_mask_sparse_1d_group = attention_mask_sparse.reshape(b, 1, 1, -1, m, k).permute(4, 0, 1, 2, 3, 5).reshape(m*b, 1, 1, -1) encoder_attention_mask_sparse = encoder_attention_mask.tile((sparse_n, 1, 1, 1)) - # if npu_config is not None: - attention_mask_sparse_1d = get_attention_mask( + + # get attention mask dtype, and shape + attention_mask_sparse_1d = npu_config.get_attention_mask( attention_mask_sparse_1d, attention_mask_sparse_1d.shape[-1] ) - attention_mask_sparse_1d_group = get_attention_mask( + attention_mask_sparse_1d_group = npu_config.get_attention_mask( attention_mask_sparse_1d_group, attention_mask_sparse_1d_group.shape[-1] ) - encoder_attention_mask_sparse_1d = get_attention_mask( + encoder_attention_mask_sparse_1d = npu_config.get_attention_mask( encoder_attention_mask_sparse, attention_mask_sparse_1d.shape[-1] ) encoder_attention_mask_sparse_1d_group = encoder_attention_mask_sparse_1d - # else: - # attention_mask_sparse_1d = attention_mask_sparse_1d.repeat_interleave(head_num, dim=1) - # attention_mask_sparse_1d_group = attention_mask_sparse_1d_group.repeat_interleave(head_num, dim=1) - - # encoder_attention_mask_sparse_1d = encoder_attention_mask_sparse.repeat_interleave(head_num, dim=1) - # encoder_attention_mask_sparse_1d_group = encoder_attention_mask_sparse_1d return { False: (attention_mask_sparse_1d, encoder_attention_mask_sparse_1d), True: (attention_mask_sparse_1d_group, encoder_attention_mask_sparse_1d_group) } - + # NO USE YET def prepare_attention_mask( self, attention_mask: ms.Tensor, target_length: int, batch_size: int, out_dim: int = 3 ) -> ms.Tensor: @@ -137,7 +123,7 @@ def prepare_attention_mask( current_length: int = attention_mask.shape[-1] if current_length != target_length: - attention_mask = mint.nn.functional.pad(attention_mask, (0, target_length), mode="constant", value=0.0) + attention_mask = mint.nn.functional.pad(attention_mask, (0, target_length), mode="constant", value=0.0) if out_dim == 3: if attention_mask.shape[0] < batch_size * head_size: @@ -157,30 +143,23 @@ class OpenSoraAttnProcessor2_0: def __init__(self, interpolation_scale_thw=(1, 1, 1), sparse1d=False, sparse_n=2, sparse_group=False, is_cross_attn=True, - FA_dtype=ms.bfloat16, dim_head=64, attention_mode = "xformers"): + dim_head=96): self.sparse1d = sparse1d self.sparse_n = sparse_n self.sparse_group = sparse_group self.is_cross_attn = is_cross_attn self.interpolation_scale_thw = interpolation_scale_thw - self.attention_mode = attention_mode self._init_rope(interpolation_scale_thw, dim_head=dim_head) - self.attention_mode = "xformers" #TBD - # Currently we only support setting attention_mode to `flash` or `math` - assert self.attention_mode in [ - "xformers", - "math", - ], f"Unsupported attention mode {self.attention_mode}. Currently we only support ['xformers', 'math']!" - self.enable_FA = self.attention_mode == "xformers" - self.FA_dtype = FA_dtype - assert self.FA_dtype in [ms.float16, ms.bfloat16], f"Unsupported flash-attention dtype: {self.FA_dtype}" - if self.enable_FA: - FLASH_IS_AVAILABLE = check_valid_flash_attention() - self.enable_FA = FLASH_IS_AVAILABLE and self.enable_FA + # if npu_config.enable_FA: + # FLASH_IS_AVAILABLE = check_valid_flash_attention() + # npu_config.enable_FA = FLASH_IS_AVAILABLE and npu_config.enable_FA + # if npu_config.enable_FA: + # npu_config.FA_dtype = FA_dtype + # assert FA_dtype in [ms.float16, ms.bfloat16], f"Unsupported flash-attention dtype: {FA_dtype}" + # self.fa_mask_dtype = choose_flash_attention_dtype() - self.fa_mask_dtype = choose_flash_attention_dtype() if get_sequence_parallel_state(): self.sp_size = hccl_info.world_size self.alltoall_sbh_q = AllToAll_SBH(scatter_dim=1, gather_dim=0) @@ -196,104 +175,7 @@ def __init__(self, interpolation_scale_thw=(1, 1, 1), def _init_rope(self, interpolation_scale_thw, dim_head): self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw, dim_head=dim_head) - self.position_getter = PositionGetter3D() - - def run_ms_flash_attention( - self, - attn, - query, - key, - value, - attention_mask, - input_layout="BSH", - attention_dropout: float = 0.0, - ): - # Memory efficient attention on mindspore uses flash attention under the hoods. - # Flash attention implementation is called `FlashAttentionScore` - # which is an experimental api with the following limitations: - # 1. Sequence length of query must be divisible by 16 and in range of [1, 32768]. - # 2. Head dimensions must be one of [64, 80, 96, 120, 128, 256]. - # 3. The input dtype must be float16 or bfloat16. - # Sequence length of query must be checked in runtime. - if input_layout not in ["BSH", "BNSD"]: - raise ValueError(f"input_layout must be in ['BSH', 'BNSD'], but get {input_layout}.") - Bs, query_tokens, _ = query.shape - assert query_tokens % 16 == 0, f"Sequence length of query must be divisible by 16, but got {query_tokens=}." - key_tokens = key.shape[1] - heads = attn.heads if not get_sequence_parallel_state() else attn.heads // hccl_info.world_size - query = query.view(Bs, query_tokens, heads, -1) - key = key.view(Bs, key_tokens, heads, -1) - value = value.view(Bs, key_tokens, heads, -1) - # Head dimension is checked in Attention.set_use_memory_efficient_attention_xformers. We maybe pad on head_dim. - if attn.head_dim_padding > 0: - query_padded = mint.nn.functional.pad(query, (0, attn.head_dim_padding), mode="constant", value=0.0) - key_padded = mint.nn.functional.pad(key, (0, attn.head_dim_padding), mode="constant", value=0.0) - value_padded = mint.nn.functional.pad(value, (0, attn.head_dim_padding), mode="constant", value=0.0) - else: - query_padded, key_padded, value_padded = query, key, value - flash_attn = ops.operations.nn_ops.FlashAttentionScore( - scale_value=attn.scale, head_num=heads, input_layout=input_layout, keep_prob=1 - attention_dropout - ) - if attention_mask is not None: - # flip mask, since ms FA treats 1 as discard, 0 as retain. - attention_mask = ~attention_mask if attention_mask.dtype == ms.bool_ else 1 - attention_mask - # (b, 1, 1, k_n) - > (b, 1, q_n, k_n), manual broadcast - if attention_mask.shape[-2] == 1: - attention_mask = mint.tile(attention_mask.bool(), (1, 1, query_tokens, 1)) - attention_mask = attention_mask.to(self.fa_mask_dtype) - - if input_layout == "BNSD": - # (b s n d) -> (b n s d) - query_padded = query_padded.swapaxes(1, 2) - key_padded = key_padded.swapaxes(1, 2) - value_padded = value_padded.swapaxes(1, 2) - elif input_layout == "BSH": - query_padded = query_padded.view(Bs, query_tokens, -1) - key_padded = key_padded.view(Bs, key_tokens, -1) - value_padded = value_padded.view(Bs, key_tokens, -1) - hidden_states_padded = flash_attn( - query_padded.to(self.FA_dtype), - key_padded.to(self.FA_dtype), - value_padded.to(self.FA_dtype), - None, - None, - None, - attention_mask, - )[3] - # If we did padding before calculate attention, undo it! - if attn.head_dim_padding > 0: - if input_layout == "BNSD": - hidden_states = hidden_states_padded[..., : attn.head_dim] - else: - hidden_states = hidden_states_padded.view(Bs, query_tokens, heads, -1)[..., : attn.head_dim] - hidden_states = hidden_states.view(Bs, query_tokens, -1) - else: - hidden_states = hidden_states_padded - if input_layout == "BNSD": - # b n s d -> b s n d - hidden_states = hidden_states.swapaxes(1, 2) - hidden_states = hidden_states.reshape(Bs, query_tokens, -1) - hidden_states = hidden_states.to(query.dtype) - return hidden_states - - def run_math_attention(self, attn, query, key, value, attention_mask): - _head_size = attn.heads if not get_sequence_parallel_state() else attn.heads // hccl_info.world_size - query = self._head_to_batch_dim(_head_size, query) - key = self._head_to_batch_dim(_head_size, key) - value = self._head_to_batch_dim(_head_size, value) - - if attention_mask is not None: - if attention_mask.ndim == 3: - attention_mask = attention_mask.unsqeeuze(1) - assert attention_mask.shape[1] == 1 - attention_mask = attention_mask.repeat_interleave(_head_size, 1) - attention_mask = attention_mask.reshape(-1, attention_mask.shape[-2], attention_mask.shape[-1]) - attention_mask = mint.zeros(attention_mask.shape).masked_fill(attention_mask.to(ms.bool_), -10000.0) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = mint.bmm(attention_probs, value) - hidden_states = self._batch_to_head_dim(_head_size, hidden_states) - return hidden_states + self.position_getter = PositionGetter3D() # TODO: need consider shapes for parallel seq and non-parallel cases def _sparse_1d(self, x, frame, height, width): @@ -359,7 +241,9 @@ def _sparse_1d_kv(self, x): require the shape of (ntokens x batch_size x dim) """ # s b d -> s (k b) d - x = x.repeat(self.sparse_n, axis = 1) + # x = repeat(x, 's b d -> s (k b) d', k = self.sparse_n) # original + # x = x.repeat(self.sparse_n, axis = 1) # WRONG!!! + x = x.tile((1, self.sparse_n, 1)) return x def __call__( @@ -459,25 +343,7 @@ def __call__( query = query.swapaxes(0, 1) # SBH to BSH key = key.swapaxes(0, 1) value = value.swapaxes(0, 1) - if self.attention_mode == "math": - # FIXME: shape error - hidden_states = self.run_math_attention(attn, query, key, value, attention_mask) - elif self.attention_mode == "xformers": - hidden_states = self.run_ms_flash_attention(attn, query, key, value, attention_mask) - # if npu_config is not None: - # hidden_states = npu_config.run_attention(query, key, value, attention_mask, "SBH", head_dim, FA_head_num) - # else: - # query = rearrange(query, 's b (h d) -> b h s d', h=FA_head_num) - # key = rearrange(key, 's b (h d) -> b h s d', h=FA_head_num) - # value = rearrange(value, 's b (h d) -> b h s d', h=FA_head_num) - # # 0, -10000 ->(bool) False, True ->(any) True ->(not) False - # # 0, 0 ->(bool) False, False ->(any) False ->(not) True - # # if attention_mask is None or not torch.any(attention_mask.bool()): # 0 mean visible - # # attention_mask = None - # # the output of sdp = (batch, num_heads, seq_len, head_dim) - # with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True): - # hidden_states = scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) # dropout_p=0.0, is_causal=False - # hidden_states = rearrange(hidden_states, 'b h s d -> s b (h d)', h=FA_head_num) + hidden_states = npu_config.run_attention(query, key, value, attention_mask, input_layout="BSH", head_dim=head_dim, head_num=FA_head_num) if self.sparse1d: hidden_states = hidden_states.swapaxes(0, 1) # BSH -> SBH @@ -523,11 +389,16 @@ def __init__( sparse1d: bool = False, sparse_n: int = 2, sparse_group: bool = False, - attention_mode: str = "xformers", FA_dtype=ms.bfloat16, ): super().__init__() - self.FA_dtype = FA_dtype + + if npu_config.enable_FA: + FLASH_IS_AVAILABLE = check_valid_flash_attention() + npu_config.enable_FA = FLASH_IS_AVAILABLE and npu_config.enable_FA + if npu_config.enable_FA: + npu_config.FA_dtype = FA_dtype + assert FA_dtype in [ms.float16, ms.bfloat16], f"Unsupported flash-attention dtype: {FA_dtype}" # Define 3 blocks. Each block has its own normalization layer. # 1. Self-Attn @@ -547,8 +418,6 @@ def __init__( sparse_n=sparse_n, sparse_group=sparse_group, is_cross_attn=False, - attention_mode=attention_mode, - FA_dtype=self.FA_dtype, ) # 2. Cross-Attn @@ -568,9 +437,7 @@ def __init__( sparse_n=sparse_n, sparse_group=sparse_group, is_cross_attn=True, - attention_mode=attention_mode, - FA_dtype=self.FA_dtype, - ) # is self-attn if encoder_hidden_states is none + ) # 3. Feed-forward self.ff = FeedForward( diff --git a/examples/opensora_pku/opensora/utils/sample_utils.py b/examples/opensora_pku/opensora/utils/sample_utils.py index d08a0dfcba..fb12092e18 100644 --- a/examples/opensora_pku/opensora/utils/sample_utils.py +++ b/examples/opensora_pku/opensora/utils/sample_utils.py @@ -216,9 +216,6 @@ def prepare_pipeline(args): logger.warning( f"Detect that the loaded model version is {model_version}, but found a mismatched number of frames {args.num_frames}. Frames needs to be 4n+1, e.g. 93, 77, 61, 45, 29, 1 (image)" ) - # dit_dtype = get_precision(args.precision) - # if dit_dtype == "fp16": # Attention processor cannot convert to fp16 - dit_dtype = None if args.version == 'v1_3': # TODO # if args.model_type == 'inpaint' or args.model_type == 'i2v': @@ -232,7 +229,6 @@ def prepare_pipeline(args): args.model_path, state_dict=state_dict, cache_dir=args.cache_dir, - # mindspore_dtype=dit_dtype, FA_dtype = FA_dtype, output_loading_info=True, ) @@ -565,7 +561,6 @@ def generate(step, data, ext, conditional_pixel_values_path=None, mask_type=None # else: for step, data in tqdm(enumerate(ds_iter), total=dataset_size): generate(step, data, ext) - break # TODO: debug use, delete later # Delete files that are no longer needed From c5110a13e4065d0c775ba817d5e93f0c48be63e7 Mon Sep 17 00:00:00 2001 From: chenyingshu Date: Thu, 14 Nov 2024 14:49:22 +0800 Subject: [PATCH 006/133] update --- examples/opensora_pku/opensora/models/diffusion/common.py | 2 -- .../models/diffusion/opensora/modeling_opensora.py | 8 ++++---- .../opensora_pku/opensora/sample/pipeline_opensora.py | 8 ++++---- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/common.py b/examples/opensora_pku/opensora/models/diffusion/common.py index a5f4a47125..1e00e98fa8 100644 --- a/examples/opensora_pku/opensora/models/diffusion/common.py +++ b/examples/opensora_pku/opensora/models/diffusion/common.py @@ -51,7 +51,6 @@ def __call__(self, b, t, h, w): pos = list(itertools.product(z, y, x)) pos = ms.Tensor(pos) if get_sequence_parallel_state(): - # print('PositionGetter3D', PositionGetter3D) pos = pos.reshape(t * h * w, 3).swapaxes(0, 1).reshape(3, -1, 1).broadcast_to((3, -1, b)) else: pos = pos.reshape(t * h * w, 3).swapaxes(0, 1).reshape(3, 1, -1).broadcast_to((3, b, -1)) @@ -121,7 +120,6 @@ def construct(self, tokens, positions): cos_y, sin_y = self.get_cos_sin(max_poses[1] + 1, self.interpolation_scale_h) cos_x, sin_x = self.get_cos_sin(max_poses[2] + 1, self.interpolation_scale_w) # split features into three along the feature dimension, and apply rope1d on each half - # t, y, x = tokens.chunk(3, dim=-1) t, y, x = mint.chunk(tokens, 3, dim=-1) t = self.apply_rope1d(t, poses[0], cos_t.to(tokens.dtype), sin_t.to(tokens.dtype)) y = self.apply_rope1d(y, poses[1], cos_y.to(tokens.dtype), sin_y.to(tokens.dtype)) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index e8a5546c46..b0c02396e6 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -304,7 +304,7 @@ def construct( height, width = hidden_states.shape[-2] // self.config.patch_size, hidden_states.shape[-1] // self.config.patch_size hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( - hidden_states, encoder_hidden_states, timestep, batch_size, frame, dtype=self.dtype + hidden_states, encoder_hidden_states, timestep, batch_size, frame ) if get_sequence_parallel_state(): @@ -367,12 +367,12 @@ def construct( return output - def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, batch_size, frame, dtype): - hidden_states = self.pos_embed(hidden_states.to(dtype)) # (b, t*h*w, d) + def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, batch_size, frame): + hidden_states = self.pos_embed(hidden_states.to(self.dtype)) # (b, t*h*w, d) added_cond_kwargs = {"resolution": None, "aspect_ratio": None} timestep, embedded_timestep = self.adaln_single( - timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=dtype + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype ) # b 6d, b d encoder_hidden_states = self.caption_projection(encoder_hidden_states) # b, 1, l, d diff --git a/examples/opensora_pku/opensora/sample/pipeline_opensora.py b/examples/opensora_pku/opensora/sample/pipeline_opensora.py index 77f120cbdb..59a408b533 100644 --- a/examples/opensora_pku/opensora/sample/pipeline_opensora.py +++ b/examples/opensora_pku/opensora/sample/pipeline_opensora.py @@ -446,7 +446,7 @@ def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma - return latents.to(dtype) + return latents def prepare_parallel_latent(self, video_states): sp_size = hccl_info.world_size @@ -567,7 +567,7 @@ def __call__( negative_prompt_attention_mask, ) = self.encode_prompt( prompt=prompt, - dtype=ms.float16, #self.transformer.dtype, + dtype=self.transformer.dtype, num_samples_per_prompt=num_samples_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, @@ -587,7 +587,7 @@ def __call__( negative_prompt_attention_mask_2, ) = self.encode_prompt( prompt=prompt, - dtype=ms.float16, #self.transformer.dtype, + dtype=self.transformer.dtype, num_samples_per_prompt=num_samples_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, @@ -790,7 +790,7 @@ def __call__( def decode_latents_per_sample(self, latents): print(f'before vae decode {latents.shape}', latents.max().item(), latents.min().item(), latents.mean().item(), latents.std().item()) video = self.vae.decode(latents).to(ms.float32) # (b t c h w) - print(f'after vae decode {latents.shape}', latents.max().item(), latents.min().item(), latents.mean().item(), latents.std().item()) + print(f'after vae decode {video.shape}', video.max().item(), video.min().item(), video.mean().item(), video.std().item()) video = ops.clip_by_value((video / 2.0 + 0.5), clip_value_min=0.0, clip_value_max=1.0).permute(0, 1, 3, 4, 2) return video # b t h w c From bde071851da862c5d34ba21c81956cc832239b09 Mon Sep 17 00:00:00 2001 From: chenyingshu Date: Thu, 14 Nov 2024 15:44:35 +0800 Subject: [PATCH 007/133] training --- .../opensora_pku/opensora/dataset/__init__.py | 7 +- .../opensora/dataset/t2v_datasets.py | 11 +- .../diffusion/opensora/modeling_opensora.py | 114 ++++++++- .../diffusion/opensora/net_with_loss.py | 5 +- .../opensora_pku/opensora/train/commons.py | 2 +- .../opensora/train/train_t2v_diffusers.py | 238 +++++++++++++----- .../opensora/utils/dataset_utils.py | 3 + .../opensora/utils/sample_utils.py | 4 +- 8 files changed, 304 insertions(+), 80 deletions(-) diff --git a/examples/opensora_pku/opensora/dataset/__init__.py b/examples/opensora_pku/opensora/dataset/__init__.py index 0123e71dec..1289ac8b90 100644 --- a/examples/opensora_pku/opensora/dataset/__init__.py +++ b/examples/opensora_pku/opensora/dataset/__init__.py @@ -44,7 +44,10 @@ def norm_func_albumentation(image, **kwargs): additional_targets=targets, ) - tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir) + tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name_1, cache_dir=args.cache_dir) + if args.text_encoder_name_2 is not None: + tokenizer_2 = AutoTokenizer.from_pretrained(args.text_encoder_name_2, cache_dir=args.cache_dir) + if args.dataset == "t2v": return T2V_dataset( dataset_file, @@ -59,7 +62,7 @@ def norm_func_albumentation(image, **kwargs): max_width=args.max_width, drop_short_ratio=args.drop_short_ratio, dataloader_num_workers=args.dataloader_num_workers, - text_encoder_name=args.text_encoder_name, + text_encoder_name=args.text_encoder_name_1, # TODO: update with 2nd text encoder return_text_emb=args.text_embed_cache, transform=transform, temporal_sample=temporal_sample, diff --git a/examples/opensora_pku/opensora/dataset/t2v_datasets.py b/examples/opensora_pku/opensora/dataset/t2v_datasets.py index 8fcc57e6b5..2909d361b0 100644 --- a/examples/opensora_pku/opensora/dataset/t2v_datasets.py +++ b/examples/opensora_pku/opensora/dataset/t2v_datasets.py @@ -210,10 +210,11 @@ def get_video(self, idx): video = self.decord_read(video_path, predefine_num_frames=len(frame_indice)) # (T H W C) h, w = video.shape[1:3] - assert h / w <= 17 / 16 and h / w >= 8 / 16, ( - f"Only videos with a ratio (h/w) less than 17/16 and more than 8/16 are supported. But video ({video_path}) " - + f"found ratio is {round(h / w, 2)} with the shape of {video.shape}" - ) + # NOTE: not suitable for 1:1 training in v1.3 + # assert h / w <= 17 / 16 and h / w >= 8 / 16, ( + # f"Only videos with a ratio (h/w) less than 17/16 and more than 8/16 are supported. But video ({video_path}) " + # + f"found ratio is {round(h / w, 2)} with the shape of {video.shape}" + # ) input_videos = {"image": video[0]} input_videos.update(dict([(f"image{i}", video[i + 1]) for i in range(len(video) - 1)])) output_videos = self.transform(**input_videos) @@ -319,7 +320,7 @@ def define_frame_index(self, cap_list): continue height, width = i["resolution"]["height"], i["resolution"]["width"] aspect = self.max_height / self.max_width - hw_aspect_thr = 1.5 + hw_aspect_thr = 2.0 #NOTE: for 1:1 frame training is_pick = filter_resolution( height, width, diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index b0c02396e6..2d477f027b 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -66,6 +66,19 @@ def __init__( self.FA_dtype = FA_dtype #NEW self._init_patched_inputs() + if self.use_recompute: + num_no_recompute = self.config.num_no_recompute + num_blocks = len(self.transformer_blocks) + assert num_no_recompute >= 0, "Expect to have num_no_recompute as a positive integer." + assert ( + num_no_recompute <= num_blocks + ), "Expect to have num_no_recompute as an integer no greater than the number of blocks," + f"but got {num_no_recompute} and {num_blocks}." + logger.info(f"Excluding {num_no_recompute} blocks from the recomputation list.") + for bidx, block in enumerate(self.transformer_blocks): + if bidx < num_blocks - num_no_recompute: + self.recompute(block) + def _init_patched_inputs(self): self.config.sample_size = (self.config.sample_size_h, self.config.sample_size_w) @@ -120,6 +133,14 @@ def _init_patched_inputs(self): pad_mode="pad" ) + def recompute(self, b): + if not b._has_config_recompute: + b.recompute(parallel_optimizer_comm_recompute=True) + if isinstance(b, nn.CellList): + self.recompute(b[-1]) + elif ms.get_context("mode") == ms.GRAPH_MODE: + b.add_flags(output_no_recompute=True) + # rewrite class method to allow the state dict as input @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): @@ -248,6 +269,65 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): return model + + @classmethod + def load_from_checkpoint(cls, model, ckpt_path): + if os.path.isdir(ckpt_path) or ckpt_path.endswith(".safetensors"): + return cls.load_from_safetensors(model, ckpt_path) + elif ckpt_path.endswith(".ckpt"): + return cls.load_from_ms_checkpoint(ckpt_path) + else: + raise ValueError("Only support safetensors pretrained ckpt or MindSpore pretrained ckpt!") + + @classmethod + def load_from_safetensors(cls, model, ckpt_path): + if os.path.isdir(ckpt_path): + ckpts = glob.glob(os.path.join(ckpt_path, "*.safetensors")) + n_ckpt = len(ckpts) + assert ( + n_ckpt == 1 + ), f"Expect to find only one safetenesors file under {ckpt_path}, but found {n_ckpt} .safetensors files." + model_file = ckpts[0] + pretrained_model_name_or_path = ckpt_path + elif ckpt_path.endswith(".safetensors"): + model_file = ckpt_path + pretrained_model_name_or_path = os.path.dirname(ckpt_path) + state_dict = load_state_dict(model_file, variant=None) + model._convert_deprecated_attention_blocks(state_dict) + + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + ) + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + logger.info(loading_info) + return model + + @classmethod + def load_from_ms_checkpoint(self, model, ckpt_path): + sd = ms.load_checkpoint(ckpt_path) + # filter 'network.' prefix + rm_prefix = ["network."] + all_pnames = list(sd.keys()) + for pname in all_pnames: + for pre in rm_prefix: + if pname.startswith(pre): + new_pname = pname.replace(pre, "") + sd[new_pname] = sd.pop(pname) + + m, u = ms.load_param_into_net(model, sd) + print("net param not load: ", m, len(m)) + print("ckpt param not load: ", u, len(u)) + return model + def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value @@ -336,17 +416,29 @@ def construct( attention_mask, encoder_attention_mask = sparse_mask[1][block.attn1.processor.sparse_group] # if self.training and self.gradient_checkpointing: #TODO: training - - hidden_states = block( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - timestep=timestep, - frame=frame, - height=height, - width=width, - ) # BSH + if self.use_recompute and ms.get_context("mode") == ms.PYNATIVE: + block_args = { + "hidden_states": hidden_states, + "attention_mask": attention_mask, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + "timestep": timestep, + "frame": frame, + "height": height, + "width": width, + } + hidden_states = ms.recompute(block, **block_args) #BSH + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + frame=frame, + height=height, + width=width, + ) # BSH if get_sequence_parallel_state(): diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py b/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py index 472274dce6..f3fd194fd5 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py @@ -20,7 +20,7 @@ class DiffusionWithLoss(nn.Cell): model (nn.Cell): A noise prediction model to denoise the encoded image latents. vae (nn.Cell): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. noise_scheduler: (object): A class for noise scheduler, such as DDPM scheduler - text_encoder (nn.Cell): A text encoding model which accepts token ids and returns text embeddings in shape (T, D). + text_encoder / text_encoder_2 (nn.Cell): A text encoding model which accepts token ids and returns text embeddings in shape (T, D). T is the number of tokens, and D is the embedding dimension. train_with_embed (bool): whether to train with embeddings (no need vae and text encoder to extract latent features and text embeddings) """ @@ -31,6 +31,7 @@ def __init__( noise_scheduler, vae: nn.Cell = None, text_encoder: nn.Cell = None, + text_encoder_2: nn.Cell = None, # not to use yet text_emb_cached: bool = True, video_emb_cached: bool = False, use_image_num: int = 0, @@ -125,7 +126,7 @@ def construct( attention_mask: ms.Tensor, text_tokens: ms.Tensor, encoder_attention_mask: ms.Tensor = None, - ): + ): # TODO: in the future add 2nd text encoder and tokens """ Video diffusion model forward and loss computation for training diff --git a/examples/opensora_pku/opensora/train/commons.py b/examples/opensora_pku/opensora/train/commons.py index 17650e2051..ef18fc4a1a 100644 --- a/examples/opensora_pku/opensora/train/commons.py +++ b/examples/opensora_pku/opensora/train/commons.py @@ -112,7 +112,7 @@ def parse_train_args(parser): "--end_learning_rate", default=1e-7, type=float, help="The end learning rate for the optimizer." ) parser.add_argument("--lr_decay_steps", default=0, type=int, help="lr decay steps.") - parser.add_argument("--lr_scheduler", default="cosine_decay", type=str, help="scheduler.") + parser.add_argument("--lr_scheduler", default="constant", type=str, help="scheduler.") parser.add_argument( "--scale_lr", default=False, diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index af072cb4ef..b9fd285d42 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -17,8 +17,7 @@ from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info from opensora.dataset import getdataset from opensora.dataset.loader import create_dataloader -from opensora.models import CausalVAEModelWrapper -from opensora.models.causalvideovae import ae_channel_config, ae_stride_config +from opensora.models.causalvideovae import ae_channel_config, ae_stride_config, ae_wrapper from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate from opensora.models.diffusion import Diffusion_models from opensora.models.diffusion.opensora.modules import Attention, LayerNorm @@ -30,16 +29,20 @@ from opensora.utils.message_utils import print_banner from opensora.utils.ms_utils import init_env from opensora.utils.utils import get_precision +from opensora.models.diffusion.common import PatchEmbed2D from mindone.diffusers.models.activations import SiLU -from mindone.diffusers.schedulers import DDPMScheduler as DDPMScheduler_diffusers +from mindone.diffusers.schedulers import ( + DDIMScheduler, DDPMScheduler, PNDMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, + FlowMatchEulerDiscreteScheduler,#CogVideoXDDIMScheduler, +) from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallbackEpoch, StopAtStepCallback from mindone.trainers.checkpoint import resume_train_network from mindone.trainers.lr_schedule import create_scheduler from mindone.trainers.optim import create_optimizer from mindone.trainers.train_step import TrainOneStepWrapper from mindone.trainers.zero import prepare_train_network -from mindone.transformers import MT5EncoderModel +from mindone.transformers import T5EncoderModel, MT5EncoderModel, CLIPTextModelWithProjection from mindone.utils.amp import auto_mixed_precision from mindone.utils.config import str2bool from mindone.utils.logger import set_logger @@ -52,6 +55,9 @@ class DDPMScheduler(DDPMScheduler_diffusers): pass +################################################################################# +# Training Loop # +################################################################################# def set_all_reduce_fusion( params, @@ -91,6 +97,9 @@ def main(args): comm_fusion=args.comm_fusion, ) set_logger(name="", output_dir=args.output_dir, rank=rank_id, log_level=eval(args.log_level)) + + # 2. Init and load models + ## Load VAE train_with_vae_latent = args.vae_latent_folder is not None and len(args.vae_latent_folder) > 0 if train_with_vae_latent: assert os.path.exists( @@ -100,16 +109,22 @@ def main(args): vae = None else: print_banner("vae init") - vae = CausalVAEModelWrapper(args.ae_path, cache_dir=args.cache_dir, use_safetensors=True) - vae_dtype = get_precision(args.vae_precision) + kwarg = { + "state_dict": state_dict, + "use_safetensors": True, + "dtype": vae_dtype, + } + vae = ae_wrapper[args.ae](args.ae_path, **kwarg) + # vae.vae_scale_factor = ae_stride_config[args.ae] + if vae_dtype == ms.float16: custom_fp32_cells = [nn.GroupNorm] if args.vae_keep_gn_fp32 else [] else: custom_fp32_cells = [nn.AvgPool2d, TrilinearInterpolate] + logger.info(f"Use amp level O2 for causal 3D VAE with dtype={vae_dtype}, custom_fp32_cells: {custom_fp32_cells}") vae = auto_mixed_precision(vae, amp_level="O2", dtype=vae_dtype, custom_fp32_cells=custom_fp32_cells) - logger.info(f"Use amp level O2 for causal 3D VAE with dtype={vae_dtype}, custom_fp32_cells {custom_fp32_cells}") - + vae.set_train(False) for param in vae.get_parameters(): # freeze vae param.requires_grad = False @@ -135,47 +150,32 @@ def main(args): assert ( args.max_height % ae_stride_h == 0 ), f"Height must be divisible by ae_stride_h, but found Height ({args.max_height}), ae_stride_h ({ae_stride_h})." + assert (args.num_frames - 1) % ae_stride_t == 0, f"(Frames - 1) must be divisible by ae_stride_t, but found num_frames ({args.num_frames}), ae_stride_t ({ae_stride_t})." assert ( args.max_width % ae_stride_h == 0 ), f"Width size must be divisible by ae_stride_h, but found Width ({args.max_width}), ae_stride_h ({ae_stride_h})." args.stride_t = ae_stride_t * patch_size_t args.stride = ae_stride_h * patch_size_h - latent_size = (args.max_height // ae_stride_h, args.max_width // ae_stride_w) - vae.latent_size = latent_size + vae.latent_size = latent_size = (args.max_height // ae_stride_h, args.max_width // ae_stride_w) + args.latent_size_t = latent_size_t = (args.num_frames - 1) // ae_stride_t + 1 - if args.num_frames % 2 == 1: - args.latent_size_t = latent_size_t = (args.num_frames - 1) // ae_stride_t + 1 - else: - latent_size_t = args.num_frames // ae_stride_t - FA_dtype = get_precision(args.precision) if get_precision(args.precision) != ms.float32 else ms.bfloat16 + + ## Load diffusion transformer print_banner("Transformer model init") + FA_dtype = get_precision(args.precision) if get_precision(args.precision) != ms.float32 else ms.bfloat16 model = Diffusion_models[args.model]( in_channels=ae_channel_config[args.ae], out_channels=ae_channel_config[args.ae], - attention_bias=True, - sample_size=latent_size, + sample_size_h=latent_size, + sample_size_w=latent_size, sample_size_t=latent_size_t, - num_vector_embeds=None, - activation_fn="gelu-approximate", - num_embeds_ada_norm=1000, - use_linear_projection=False, - only_cross_attention=False, - double_self_attention=False, - upcast_attention=False, - # norm_type="ada_norm_single", - norm_elementwise_affine=False, - norm_eps=1e-6, - attention_type="default", - attention_mode=args.attention_mode, interpolation_scale_h=args.interpolation_scale_h, interpolation_scale_w=args.interpolation_scale_w, interpolation_scale_t=args.interpolation_scale_t, - downsampler=args.downsampler, - # compress_kv_factor=args.compress_kv_factor, - use_rope=args.use_rope, - # model_max_length=args.model_max_length, - use_stable_fp32=args.enable_stable_fp32, + sparse1d=args.sparse1d, + sparse_n=args.sparse_n, + skip_connection=args.skip_connection, use_recompute=args.gradient_checkpointing, num_no_recompute=args.num_no_recompute, FA_dtype=FA_dtype, @@ -188,13 +188,12 @@ def main(args): model_dtype = get_precision(args.precision) if not args.global_bf16: if model_dtype == ms.float16: - custom_fp32_cells = [LayerNorm, nn.SiLU, SiLU, nn.GELU] - if args.attention_mode == "math": - custom_fp32_cells += [Attention] + custom_fp32_cells = [LayerNorm, Attention, PatchEmbed2D, nn.SiLU, SiLU, nn.GELU] else: custom_fp32_cells = [ nn.MaxPool2d, nn.MaxPool3d, + PatchEmbed2D, LayerNorm, nn.SiLU, SiLU, @@ -222,22 +221,61 @@ def main(args): logger.info("Use random initialization for transformer") model.set_train(True) + ## Load text encoder if not args.text_embed_cache: print_banner("text encoder init") text_encoder_dtype = get_precision(args.text_encoder_precision) - text_encoder, loading_info = MT5EncoderModel.from_pretrained( - args.text_encoder_name, - cache_dir=args.cache_dir, - output_loading_info=True, - mindspore_dtype=text_encoder_dtype, - use_safetensors=True, - ) - logger.info(loading_info) + if 'mt5' in args.text_encoder_name_1: + text_encoder_1, loading_info = MT5EncoderModel.from_pretrained( + args.text_encoder_name_1, + cache_dir=args.cache_dir, + output_loading_info=True, + mindspore_dtype=text_encoder_dtype, + use_safetensors=True + ) + loading_info.pop("unexpected_keys") # decoder weights are ignored + logger.info(f"Loaded MT5 Encoder: {loading_info}") + text_encoder_1 = text_encoder_1.set_train(False) + else: + text_encoder_1 = T5EncoderModel.from_pretrained( + args.text_encoder_name_1, cache_dir=args.cache_dir, + mindspore_dtype=text_encoder_dtype + ).set_train(False) + text_encoder_2 = None + if args.text_encoder_name_2 is not None: + text_encoder_2, loading_info = CLIPTextModelWithProjection.from_pretrained( + args.text_encoder_name_2, + cache_dir=args.cache_dir, + mindspore_dtype=text_encoder_dtype, + output_loading_info=True, + use_safetensors=True, + ) + loading_info.pop("unexpected_keys") # only load text model, ignore vision model + # loading_info.pop("mising_keys") # Note: missed keys when loading open-clip models + logger.info(f"Loaded CLIP Encoder: {loading_info}") + text_encoder_2 = text_encoder_2.set_train(False) else: - text_encoder = None + text_encoder_1 = None + text_encoder_2 = None text_encoder_dtype = None - noise_scheduler = DDPMScheduler() + kwargs = dict( + prediction_type=args.prediction_type, + rescale_betas_zero_snr=args.rescale_betas_zero_snr + ) + if args.cogvideox_scheduler: + noise_scheduler = CogVideoXDDIMScheduler(**kwargs) + elif args.v1_5_scheduler: + kwargs['beta_start'] = 0.00085 + kwargs['beta_end'] = 0.0120 + kwargs['beta_schedule'] = "scaled_linear" + noise_scheduler = DDPMScheduler(**kwargs) + elif args.rf_scheduler: + noise_scheduler = FlowMatchEulerDiscreteScheduler() + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + else: + noise_scheduler = DDPMScheduler(**kwargs) + # Get the target for loss depending on the prediction type if args.prediction_type is not None: # set prediction_type of scheduler if defined @@ -251,11 +289,12 @@ def main(args): logger.info("Training on image datasets only.") else: logger.info("Training on video datasets only.") + latent_diffusion_with_loss = DiffusionWithLoss( model, noise_scheduler, vae=vae, - text_encoder=text_encoder, + text_encoder=text_encoder_1, text_emb_cached=args.text_embed_cache, video_emb_cached=False, use_image_num=args.use_image_num, @@ -264,22 +303,34 @@ def main(args): snr_gamma=args.snr_gamma, ) latent_diffusion_eval, metrics, eval_indexes = None, None, None + # 3. create dataset # TODO: replace it with new dataset assert args.dataset == "t2v", "Support t2v dataset only." print_banner("Training dataset Loading...") + # Setup data: + # TODO: to use in v1.3 + if args.trained_data_global_step is not None: + initial_global_step_for_sampler = args.trained_data_global_step + else: + initial_global_step_for_sampler = 0 + if args.max_hxw is not None and args.min_hxw is None: + args.min_hxw = args.max_hxw // 4 + train_dataset = getdataset(args, dataset_file=args.data) sampler = ( LengthGroupedBatchSampler( args.train_batch_size, world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), lengths=train_dataset.lengths, - group_frame=args.group_frame, - group_resolution=args.group_resolution, + group_frame=args.group_frame, #v1.2 + group_resolution=args.group_resolution, #v1.2 + initial_global_step_for_sampler = initial_global_step_for_sampler, #TODO: use in v1.3 + group_data=args.group_data #TODO: use in v1.3 ) - if (args.group_frame or args.group_resolution) - else None + if (args.group_frame or args.group_resolution) #v1.2 + else None #v1.2 ) collate_fn = Collate( args.train_batch_size, @@ -365,7 +416,7 @@ def main(args): model, noise_scheduler, vae=vae, - text_encoder=text_encoder, + text_encoder=text_encoder_1, text_emb_cached=args.text_embed_cache, video_emb_cached=False, use_image_num=args.use_image_num, @@ -623,12 +674,14 @@ def main(args): callback.append(rec_cb) if args.profile: callback.append(ProfilerCallbackEpoch(2, 2, "./profile_data")) + # Train! assert ( args.train_sp_batch_size == 1 ), "Do not support train_sp_batch_size other than 1. Please set `--train_sp_batch_size 1`" total_batch_size = args.train_batch_size * device_num * args.gradient_accumulation_steps total_batch_size = total_batch_size // args.sp_size * args.train_sp_batch_size + # 5. log and save config if rank_id == 0: if vae is not None: @@ -708,6 +761,77 @@ def main(args): def parse_t2v_train_args(parser): + ######## TODO: NEW in v1.3 , but may not use ### + # dataset & dataloader + parser.add_argument("--max_hxw", type=int, default=None) + parser.add_argument("--min_hxw", type=int, default=None) + parser.add_argument("--ood_img_ratio", type=float, default=0.0) + parser.add_argument("--group_data", action="store_true") + parser.add_argument("--hw_stride", type=int, default=32) + parser.add_argument("--force_resolution", action="store_true") + parser.add_argument("--trained_data_global_step", type=int, default=None) + parser.add_argument("--use_decord", action="store_true") + + # text encoder & vae & diffusion model + parser.add_argument('--vae_fp32', action='store_true') + parser.add_argument('--extra_save_mem', action='store_true') + parser.add_argument("--text_encoder_name_1", type=str, default='DeepFloyd/t5-v1_1-xxl') + parser.add_argument("--text_encoder_name_2", type=str, default=None) + parser.add_argument('--sparse1d', action='store_true') + parser.add_argument('--sparse_n', type=int, default=2) + parser.add_argument('--skip_connection', action='store_true') + parser.add_argument('--cogvideox_scheduler', action='store_true') + parser.add_argument('--v1_5_scheduler', action='store_true') + parser.add_argument('--rf_scheduler', action='store_true') + parser.add_argument("--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]) + parser.add_argument("--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme.") + parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") + parser.add_argument("--mode_scale", type=float, default=1.29, help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.") + + # diffusion setting + parser.add_argument("--offload_ema", action="store_true", help="Offload EMA model to CPU during training step.") + parser.add_argument("--foreach_ema", action="store_true", help="Use faster foreach implementation of EMAModel.") + parser.add_argument('--rescale_betas_zero_snr', action='store_true') + + # validation & logs + parser.add_argument("--enable_profiling", action="store_true") + parser.add_argument("--num_sampling_steps", type=int, default=20) + parser.add_argument('--guidance_scale', type=float, default=4.5) + parser.add_argument("--checkpoints_total_limit", type=int, default=None, help=("Max number of checkpoints to store.")) + + # optimizer & scheduler + parser.add_argument("--optimizer", type=str, default="adamW", help='The optimizer type to use. Choose between ["AdamW", "prodigy"]') + parser.add_argument("--learning_rate", type=float, default=1e-4, help="Initial learning rate (after the potential warmup period) to use.") + parser.add_argument("--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW") + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers.") + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-02, help="Weight decay to use for unet params") + parser.add_argument("--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder") + parser.add_argument("--adam_epsilon", type=float, default=1e-15, help="Epsilon value for the Adam optimizer and Prodigy optimizers.") + parser.add_argument("--prodigy_use_bias_correction", type=bool, default=True, help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW") + parser.add_argument("--prodigy_safeguard_warmup", type=bool, default=True, help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. Ignored if optimizer is adamW") + parser.add_argument("--prodigy_beta3", type=float, default=None, + help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--allow_tf32", action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + ######################## + parser.add_argument("--output_dir", default="outputs/", help="The directory where training results are saved.") parser.add_argument("--dataset", type=str, default="t2v") parser.add_argument( @@ -736,7 +860,7 @@ def parse_t2v_train_args(parser): help="Whether to use T5 embedding cache. Must be provided in image/video_data.", ) parser.add_argument("--vae_latent_folder", default=None, type=str, help="root dir for the vae latent data") - parser.add_argument("--model", type=str, choices=list(Diffusion_models.keys()), default="OpenSoraT2V-ROPE-L/122") + parser.add_argument("--model", type=str, choices=list(Diffusion_models.keys()), default="OpenSoraT2V_v1_3-2B/122") parser.add_argument("--interpolation_scale_h", type=float, default=1.0) parser.add_argument("--interpolation_scale_w", type=float, default=1.0) parser.add_argument("--interpolation_scale_t", type=float, default=1.0) @@ -765,8 +889,8 @@ def parse_t2v_train_args(parser): parser.add_argument("--enable_tiling", action="store_true") parser.add_argument("--attention_mode", type=str, choices=["xformers", "math", "flash"], default="xformers") - parser.add_argument("--text_encoder_name", type=str, default="DeepFloyd/t5-v1_1-xxl") - parser.add_argument("--model_max_length", type=int, default=300) + # parser.add_argument("--text_encoder_name", type=str, default="DeepFloyd/t5-v1_1-xxl") + parser.add_argument("--model_max_length", type=int, default=512) parser.add_argument("--multi_scale", action="store_true") # parser.add_argument("--enable_tracker", action="store_true") @@ -781,7 +905,7 @@ def parse_t2v_train_args(parser): parser.add_argument( "--checkpointing_steps", type=int, - default=None, + default=500, help=( "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" diff --git a/examples/opensora_pku/opensora/utils/dataset_utils.py b/examples/opensora_pku/opensora/utils/dataset_utils.py index 5f96d12681..be4c472977 100644 --- a/examples/opensora_pku/opensora/utils/dataset_utils.py +++ b/examples/opensora_pku/opensora/utils/dataset_utils.py @@ -398,8 +398,10 @@ def __init__( batch_size: int, world_size: int, lengths: Optional[List[int]] = None, + initial_global_step_for_sampler: int = 0, group_frame=False, group_resolution=False, + group_data = False, generator=None, ): if lengths is None: @@ -411,6 +413,7 @@ def __init__( self.lengths = lengths self.group_frame = group_frame self.group_resolution = group_resolution + self.group_data = group_data self.generator = generator self.remainder = len(self) * self.megabatch_size != len(self.lengths) diff --git a/examples/opensora_pku/opensora/utils/sample_utils.py b/examples/opensora_pku/opensora/utils/sample_utils.py index fb12092e18..31733a5419 100644 --- a/examples/opensora_pku/opensora/utils/sample_utils.py +++ b/examples/opensora_pku/opensora/utils/sample_utils.py @@ -32,8 +32,8 @@ from opensora.models.causalvideovae import ae_stride_config, ae_wrapper # from opensora.sample.caption_refiner import OpenSoraCaptionRefiner from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate -from examples.opensora_pku.opensora.models.diffusion.opensora.modeling_opensora import LayerNorm, OpenSoraT2V_v1_3 -from examples.opensora_pku.opensora.models.diffusion.opensora.modules import Attention +from examples.opensora_pku.opensora.models.diffusion.opensora.modeling_opensora import OpenSoraT2V_v1_3 +from examples.opensora_pku.opensora.models.diffusion.opensora.modules import Attention, LayerNorm from opensora.sample.pipeline_opensora import OpenSoraPipeline from opensora.models.diffusion.common import PatchEmbed2D From fce6c11bf1227a34eceb6aad7f05947a966dd265 Mon Sep 17 00:00:00 2001 From: Yingshu CHEN Date: Thu, 14 Nov 2024 15:47:48 +0800 Subject: [PATCH 008/133] Create video_data_v1_2.txt --- examples/opensora_pku/scripts/train_data/video_data_v1_2.txt | 1 + 1 file changed, 1 insertion(+) create mode 100644 examples/opensora_pku/scripts/train_data/video_data_v1_2.txt diff --git a/examples/opensora_pku/scripts/train_data/video_data_v1_2.txt b/examples/opensora_pku/scripts/train_data/video_data_v1_2.txt new file mode 100644 index 0000000000..d2853a91d2 --- /dev/null +++ b/examples/opensora_pku/scripts/train_data/video_data_v1_2.txt @@ -0,0 +1 @@ +/home_host/susan/workspace/datasets/Open-Sora-Plan-v1.2.0,/home_host/susan/workspace/datasets/Open-Sora-Plan-v1.2.0/mixkit_emb-len=512,/home_host/susan/workspace/datasets/Open-Sora-Plan-v1.2.0/v1.1.0_HQ_part1_Traffic_train.json From 8cdf20e155e9aa9a2a4aa860c4adf28e90800868 Mon Sep 17 00:00:00 2001 From: Yingshu CHEN Date: Thu, 14 Nov 2024 15:50:27 +0800 Subject: [PATCH 009/133] Create train_debug.sh --- .../single-device/train_debug.sh | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 examples/opensora_pku/scripts/text_condition/single-device/train_debug.sh diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_debug.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_debug.sh new file mode 100644 index 0000000000..6fb00f263f --- /dev/null +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_debug.sh @@ -0,0 +1,72 @@ +''' +Training scheduler +We replaced the eps-pred loss with v-pred loss and enable ZeroSNR. For videos, we resample to 16 FPS for training. + +Stage 1: We initially initialized from the image weights of version 1.2.0 and trained images at a resolution of 1x320x320. The objective of this phase was to fine-tune the 3D dense attention model to a sparse attention model. The entire fine-tuning process involved approximately 100k steps, with a batch size of 1024 and a learning rate of 2e-5. The image data was primarily sourced from SAM in version 1.2.0. + +Stage 2: We trained the model jointly on images and videos, with a maximum resolution of 93x320x320. +The entire fine-tuning process involved approximately 300k steps, with a batch size of 1024 and a learning rate of 2e-5. +The image data was primarily sourced from SAM in version 1.2.0, while the video data consisted of the unfiltered Panda70m. +In fact, the model had nearly converged around 100k steps, and by 300k steps, there were no significant gains. +Subsequently, we performed data cleaning and caption rewriting, with further data analysis discussed at the end. + +Stage 3: We fine-tuned the model using our filtered Panda70m dataset, with a fixed resolution of 93x352x640. The entire fine-tuning process involved approximately 30k steps, with a batch size of 1024 and a learning rate of 1e-5. +''' + +# Stage 2: 93x320x320 +export DEVICE_ID=0 +NUM_FRAME=29 +HEIGHT=320 +WIDTH=320 +python opensora/train/train_t2v_diffusers.py \ + --model OpenSoraT2V_v1_3-2B/122 \ + --text_encoder_name_1 /home_host/susan/workspace/checkpoints/google/mt5-xxl \ + --cache_dir "./" \ + --dataset t2v \ + --data "scripts/train_data/video_data_v1_2.txt" \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path /home_host/susan/workspace/checkpoints/LanguageBind/Open-Sora-Plan-v1.3.0/vae \ + --sample_rate 1 \ + --num_frames ${NUM_FRAME} \ + --max_height ${HEIGHT} \ + --max_width ${WIDTH} \ + --force_resolution \ + --interpolation_scale_t 1.0 \ + --interpolation_scale_h 1.0 \ + --interpolation_scale_w 1.0 \ + --gradient_checkpointing \ + --train_batch_size=1 \ + --dataloader_num_workers 1 \ + --gradient_accumulation_steps=1 \ + --max_train_steps 1000000 \ + --start_learning_rate=2e-5 \ + --lr_scheduler="constant" \ + --seed=10 \ + --lr_warmup_steps=500 \ + --precision="bf16" \ + --checkpointing_steps=1000 \ + --output_dir="./checkpoints/t2v-${NUM_FRAME}x${HEIGHT}x${WIDTH}/" \ + --model_max_length 512 \ + --use_image_num 0 \ + --cfg 0.1 \ + --snr_gamma 5.0 \ + --use_ema True\ + --ema_start_step 0 \ + --enable_tiling \ + --tile_overlap_factor 0.125 \ + --clip_grad True \ + --max_grad_norm 1.0 \ + --use_rope \ + --noise_offset 0.02 \ + --enable_stable_fp32 True \ + --ema_decay 0.999 \ + --speed_factor 1.0 \ + --drop_short_ratio 1.0 \ + --hw_stride 32 \ + --sparse1d \ + --sparse_n 4 \ + --train_fps 16 \ + --trained_data_global_step 0 \ + --group_data \ + --prediction_type "v_prediction" \ + --mode 1 From 1a51b8d99d6fba44da1e26ae902338fea961dfbe Mon Sep 17 00:00:00 2001 From: Yingshu CHEN Date: Thu, 14 Nov 2024 15:51:17 +0800 Subject: [PATCH 010/133] Create train_debug.sh --- .../multi-devices/train_debug.sh | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh new file mode 100644 index 0000000000..02ac55514f --- /dev/null +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh @@ -0,0 +1,63 @@ +# Stage 2: 93x320x320 +NUM_FRAME=29 +WIDTH=320 +HEIGHT=320 +ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 +msrun --bind_core=True --worker_num=4 --local_worker_num=4 --master_port=6000 --log_dir="./checkpoints/t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}_zero2_mode1_npu4/parallel_logs" \ + opensora/train/train_t2v_diffusers.py \ + --model OpenSoraT2V_v1_3-2B/122 \ + --text_encoder_name_1 /home_host/susan/workspace/checkpoints/google/mt5-xxl \ + --cache_dir "./" \ + --dataset t2v \ + --data "scripts/train_data/video_data_v1_2.txt" \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path /home_host/susan/workspace/checkpoints/LanguageBind/Open-Sora-Plan-v1.3.0/vae \ + --sample_rate 1 \ + --num_frames ${NUM_FRAME} \ + --max_height ${HEIGHT} \ + --max_width ${WIDTH} \ + --interpolation_scale_t 1.0 \ + --interpolation_scale_h 1.0 \ + --interpolation_scale_w 1.0 \ + --gradient_checkpointing \ + --train_batch_size=1 \ + --dataloader_num_workers 4 \ + --gradient_accumulation_steps=1 \ + --max_train_steps=1000000 \ + --start_learning_rate=2e-5 \ + --lr_scheduler="constant" \ + --seed=10 \ + --lr_warmup_steps=500 \ + --precision="bf16" \ + --checkpointing_steps=1000 \ + --output_dir="./checkpoints/t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}_zero2_mode1_npu4/" \ + --model_max_length 512 \ + --use_image_num 0 \ + --cfg 0.1 \ + --snr_gamma 5.0 \ + --use_ema True\ + --ema_start_step 0 \ + --enable_tiling \ + --tile_overlap_factor 0.125 \ + --clip_grad True \ + --max_grad_norm 1.0 \ + --noise_offset 0.02 \ + --enable_stable_fp32 True\ + --ema_decay 0.999 \ + --speed_factor 1.0 \ + --drop_short_ratio 1.0 \ + --use_parallel True \ + --parallel_mode "zero" \ + --zero_stage 2 \ + --max_device_memory "58GB" \ + --jit_syntax_level "lax" \ + --dataset_sink_mode True \ + --num_no_recompute 18 \ + --prediction_type "v_prediction" \ + --hw_stride 32 \ + --sparse1d \ + --sparse_n 4 \ + --train_fps 16 \ + --trained_data_global_step 0 \ + --group_data \ + --mode 1 From bc42b7a61c3dd5de554af6eb353c5a5a9f8bb834 Mon Sep 17 00:00:00 2001 From: chenyingshu Date: Thu, 14 Nov 2024 16:14:43 +0800 Subject: [PATCH 011/133] fix bug --- .../diffusion/opensora/modeling_opensora.py | 4 +++- .../opensora/train/train_t2v_diffusers.py | 15 ++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index 2d477f027b..6208c43748 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -21,6 +21,8 @@ from opensora.models.diffusion.common import PatchEmbed2D from opensora.npu_config import npu_config +logger = logging.getLogger(__name__) + class OpenSoraT2V_v1_3(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @@ -416,7 +418,7 @@ def construct( attention_mask, encoder_attention_mask = sparse_mask[1][block.attn1.processor.sparse_group] # if self.training and self.gradient_checkpointing: #TODO: training - if self.use_recompute and ms.get_context("mode") == ms.PYNATIVE: + if self.use_recompute and ms.get_context("mode") == ms.PYNATIVE_MODE: block_args = { "hidden_states": hidden_states, "attention_mask": attention_mask, diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index b9fd285d42..99e3c94ff2 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -51,13 +51,10 @@ logger = logging.getLogger(__name__) -@ms.jit_class -class DDPMScheduler(DDPMScheduler_diffusers): - pass +# @ms.jit_class +# class DDPMScheduler(DDPMScheduler_diffusers): +# pass -################################################################################# -# Training Loop # -################################################################################# def set_all_reduce_fusion( params, @@ -76,6 +73,10 @@ def set_all_reduce_fusion( ms.set_auto_parallel_context(all_reduce_fusion_config=split_list) +################################################################################# +# Training Loop # +################################################################################# + def main(args): # 1. init save_src_strategy = args.use_parallel and args.parallel_mode != "data" @@ -111,7 +112,7 @@ def main(args): print_banner("vae init") vae_dtype = get_precision(args.vae_precision) kwarg = { - "state_dict": state_dict, + "state_dict": None, "use_safetensors": True, "dtype": vae_dtype, } From 562fefa311db01074546530c88ed248596e994c5 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 19 Nov 2024 16:57:10 +0800 Subject: [PATCH 012/133] fix bug for recompute and sp --- .../diffusion/opensora/modeling_opensora.py | 253 ++++++++++-------- .../models/diffusion/opensora/modules.py | 186 +++++++------ 2 files changed, 237 insertions(+), 202 deletions(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index 6208c43748..b91a089da7 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -5,6 +5,9 @@ from typing import Any, Dict, Optional from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info +from opensora.models.diffusion.common import PatchEmbed2D +from opensora.models.diffusion.opensora.modules import Attention, BasicTransformerBlock, LayerNorm +from opensora.npu_config import npu_config from opensora.utils.utils import to_2tuple import mindspore as ms @@ -17,12 +20,9 @@ from mindone.diffusers.models.normalization import AdaLayerNormSingle from mindone.diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, _add_variant, _get_model_file -from opensora.models.diffusion.opensora.modules import BasicTransformerBlock, LayerNorm, Attention -from opensora.models.diffusion.common import PatchEmbed2D -from opensora.npu_config import npu_config - logger = logging.getLogger(__name__) + class OpenSoraT2V_v1_3(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @@ -54,18 +54,17 @@ def __init__( interpolation_scale_t: float = 1.0, sparse1d: bool = False, sparse_n: int = 2, - - use_recompute=False, #NEW - FA_dtype=ms.bfloat16, #NEW - num_no_recompute: int = 0, #NEW + use_recompute=False, # NEW + FA_dtype=ms.bfloat16, # NEW + num_no_recompute: int = 0, # NEW ): super().__init__() # Set some common variables used across the board. self.out_channels = in_channels if out_channels is None else out_channels - self.config.hidden_size = self.config.num_attention_heads * self.config.attention_head_dim #24*96=2304 - self.gradient_checkpointing = use_recompute #NEW - self.use_recompute = use_recompute #NEW - self.FA_dtype = FA_dtype #NEW + self.config.hidden_size = self.config.num_attention_heads * self.config.attention_head_dim # 24*96=2304 + self.gradient_checkpointing = use_recompute # NEW + self.use_recompute = use_recompute # NEW + self.FA_dtype = FA_dtype # NEW self._init_patched_inputs() if self.use_recompute: @@ -82,22 +81,21 @@ def __init__( self.recompute(block) def _init_patched_inputs(self): - self.config.sample_size = (self.config.sample_size_h, self.config.sample_size_w) interpolation_scale_thw = ( - self.config.interpolation_scale_t, - self.config.interpolation_scale_h, - self.config.interpolation_scale_w - ) - + self.config.interpolation_scale_t, + self.config.interpolation_scale_h, + self.config.interpolation_scale_w, + ) + self.caption_projection = PixArtAlphaTextProjection( in_features=self.config.caption_channels, hidden_size=self.config.hidden_size ) - + self.pos_embed = PatchEmbed2D( - patch_size=self.config.patch_size, #2 - in_channels=self.config.in_channels, #8 - embed_dim=self.config.hidden_size, #2304 + patch_size=self.config.patch_size, # 2 + in_channels=self.config.in_channels, # 8 + embed_dim=self.config.hidden_size, # 2304 ) self.transformer_blocks = nn.CellList( [ @@ -114,11 +112,11 @@ def _init_patched_inputs(self): upcast_attention=self.config.upcast_attention, norm_elementwise_affine=self.config.norm_elementwise_affine, norm_eps=self.config.norm_eps, - interpolation_scale_thw=interpolation_scale_thw, - sparse1d=self.config.sparse1d if i > 1 and i < 30 else False, - sparse_n=self.config.sparse_n, - sparse_group=i % 2 == 1, - FA_dtype=self.FA_dtype + interpolation_scale_thw=interpolation_scale_thw, + sparse1d=self.config.sparse1d if i > 1 and i < 30 else False, + sparse_n=self.config.sparse_n, + sparse_group=i % 2 == 1, + FA_dtype=self.FA_dtype, ) for i in range(self.config.num_layers) ] @@ -126,13 +124,14 @@ def _init_patched_inputs(self): self.norm_out = LayerNorm(self.config.hidden_size, elementwise_affine=False, eps=1e-6) self.scale_shift_table = ms.Parameter(ops.randn((2, self.config.hidden_size)) / self.config.hidden_size**0.5) self.proj_out = nn.Dense( - self.config.hidden_size, self.config.patch_size_t * self.config.patch_size * self.config.patch_size * self.out_channels + self.config.hidden_size, + self.config.patch_size_t * self.config.patch_size * self.config.patch_size * self.out_channels, ) self.adaln_single = AdaLayerNormSingle(self.config.hidden_size) self.max_pool3d = nn.MaxPool3d( - kernel_size=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size), + kernel_size=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size), stride=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size), - pad_mode="pad" + pad_mode="pad", ) def recompute(self, b): @@ -271,7 +270,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): return model - @classmethod def load_from_checkpoint(cls, model, ckpt_path): if os.path.isdir(ckpt_path) or ckpt_path.endswith(".safetensors"): @@ -333,11 +331,11 @@ def load_from_ms_checkpoint(self, model, ckpt_path): def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value - + def get_attention_mask(self, attention_mask): if attention_mask is not None: if not npu_config.enable_FA: - attention_mask = attention_mask.to(ms.bool_) # use bool for sdpa + attention_mask = attention_mask.to(ms.bool_) # use bool for sdpa return attention_mask def construct( @@ -348,7 +346,7 @@ def construct( attention_mask: Optional[ms.Tensor] = None, encoder_attention_mask: Optional[ms.Tensor] = None, return_dict: bool = True, - **kwargs, + **kwargs, ): batch_size, c, frame, h, w = hidden_states.shape # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. @@ -372,18 +370,22 @@ def construct( attention_mask = self.max_pool3d(attention_mask) # b 1 t h w -> (b 1) 1 (t h w) attention_mask = attention_mask.reshape(batch_size, 1, -1) - - attention_mask = self.get_attention_mask(attention_mask) # if use bool mask + + attention_mask = self.get_attention_mask(attention_mask) # if use bool mask # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # b, 1, l - encoder_attention_mask = self.get_attention_mask(encoder_attention_mask) # if use bool mask - + encoder_attention_mask = self.get_attention_mask(encoder_attention_mask) # if use bool mask # 1. Input - frame = ((frame - 1) // self.config.patch_size_t + 1) if frame % 2 == 1 else frame // self.config.patch_size_t # patchfy - height, width = hidden_states.shape[-2] // self.config.patch_size, hidden_states.shape[-1] // self.config.patch_size + frame = ( + ((frame - 1) // self.config.patch_size_t + 1) if frame % 2 == 1 else frame // self.config.patch_size_t + ) # patchfy + height, width = ( + hidden_states.shape[-2] // self.config.patch_size, + hidden_states.shape[-1] // self.config.patch_size, + ) hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( hidden_states, encoder_hidden_states, timestep, batch_size, frame @@ -396,7 +398,7 @@ def construct( # b s h -> s b h hidden_states = hidden_states.swapaxes(0, 1).contiguous() # b s h -> s b h - encoder_hidden_states = encoder_hidden_states.swapaxes(0,1).contiguous() + encoder_hidden_states = encoder_hidden_states.swapaxes(0, 1).contiguous() timestep = timestep.view(batch_size, 6, -1).swapaxes(0, 1).contiguous() sparse_mask = {} @@ -408,40 +410,29 @@ def construct( # else: head_num = None for sparse_n in [1, 4]: - sparse_mask[sparse_n] = Attention.prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n, head_num) - + sparse_mask[sparse_n] = Attention.prepare_sparse_mask( + attention_mask, encoder_attention_mask, sparse_n, head_num + ) + # 2. Blocks for i, block in enumerate(self.transformer_blocks): if i > 1 and i < 30: - attention_mask, encoder_attention_mask = sparse_mask[block.attn1.processor.sparse_n][block.attn1.processor.sparse_group] + attention_mask, encoder_attention_mask = sparse_mask[block.attn1.processor.sparse_n][ + block.attn1.processor.sparse_group + ] else: attention_mask, encoder_attention_mask = sparse_mask[1][block.attn1.processor.sparse_group] - # if self.training and self.gradient_checkpointing: #TODO: training - if self.use_recompute and ms.get_context("mode") == ms.PYNATIVE_MODE: - block_args = { - "hidden_states": hidden_states, - "attention_mask": attention_mask, - "encoder_hidden_states": encoder_hidden_states, - "encoder_attention_mask": encoder_attention_mask, - "timestep": timestep, - "frame": frame, - "height": height, - "width": width, - } - hidden_states = ms.recompute(block, **block_args) #BSH - else: - hidden_states = block( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - timestep=timestep, - frame=frame, - height=height, - width=width, - ) # BSH - + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + frame=frame, + height=height, + width=width, + ) # BSH if get_sequence_parallel_state(): # To (b, t*h*w, h) or (b, t//sp*h*w, h) @@ -453,35 +444,33 @@ def construct( hidden_states=hidden_states, timestep=timestep, embedded_timestep=embedded_timestep, - num_frames=frame, + num_frames=frame, height=height, width=width, ) # b c t h w - return output def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, batch_size, frame): - hidden_states = self.pos_embed(hidden_states.to(self.dtype)) # (b, t*h*w, d) - + hidden_states = self.pos_embed(hidden_states.to(self.dtype)) # (b, t*h*w, d) + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} timestep, embedded_timestep = self.adaln_single( timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype ) # b 6d, b d encoder_hidden_states = self.caption_projection(encoder_hidden_states) # b, 1, l, d - assert encoder_hidden_states.shape[1] == 1, f'encoder_hidden_states.shape is {encoder_hidden_states}' + assert encoder_hidden_states.shape[1] == 1, f"encoder_hidden_states.shape is {encoder_hidden_states}" # b 1 l d -> (b 1) l d - encoder_hidden_states = encoder_hidden_states.reshape(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]) + encoder_hidden_states = encoder_hidden_states.reshape( + -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1] + ) return hidden_states, encoder_hidden_states, timestep, embedded_timestep - - def _get_output_for_patched_inputs( - self, hidden_states, timestep, embedded_timestep, num_frames, height, width - ): + def _get_output_for_patched_inputs(self, hidden_states, timestep, embedded_timestep, num_frames, height, width): shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, axis=1) - hidden_states = self.norm_out(hidden_states) #BSH -> BSH + hidden_states = self.norm_out(hidden_states) # BSH -> BSH hidden_states = hidden_states.squeeze(1) if hidden_states.shape[1] == 1 else hidden_states # Modulation @@ -491,22 +480,41 @@ def _get_output_for_patched_inputs( # unpatchify hidden_states = hidden_states.reshape( - -1, num_frames, height, width, self.config.patch_size_t, self.config.patch_size, self.config.patch_size, self.out_channels + -1, + num_frames, + height, + width, + self.config.patch_size_t, + self.config.patch_size, + self.config.patch_size, + self.out_channels, ) # nthwopqc -> nctohpwq hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) output = hidden_states.reshape( - -1, self.out_channels, - num_frames * self.config.patch_size_t, height * self.config.patch_size, width * self.config.patch_size + -1, + self.out_channels, + num_frames * self.config.patch_size_t, + height * self.config.patch_size, + width * self.config.patch_size, ) return output + def OpenSoraT2V_v1_3_2B_122(**kwargs): - kwargs.pop('skip_connection', None) + kwargs.pop("skip_connection", None) return OpenSoraT2V_v1_3( - num_layers=32, attention_head_dim=96, num_attention_heads=24, patch_size_t=1, patch_size=2, - caption_channels=4096, cross_attention_dim=2304, activation_fn="gelu-approximate", **kwargs - ) + num_layers=32, + attention_head_dim=96, + num_attention_heads=24, + patch_size_t=1, + patch_size=2, + caption_channels=4096, + cross_attention_dim=2304, + activation_fn="gelu-approximate", + **kwargs, + ) + OpenSora_v1_3_models = { "OpenSoraT2V_v1_3-2B/122": OpenSoraT2V_v1_3_2B_122, # 2.7B @@ -516,24 +524,26 @@ def OpenSoraT2V_v1_3_2B_122(**kwargs): "OpenSoraT2V_v1_3-2B/122": OpenSoraT2V_v1_3, } -if __name__ == '__main__': +if __name__ == "__main__": from opensora.models.causalvideovae import ae_stride_config - args = type('args', (), - { - 'ae': "WFVAEModel_D8_4x8x8", - 'model_max_length': 300, - 'max_height': 256, - 'max_width': 512, - 'num_frames': 33, - 'compress_kv_factor': 1, - 'interpolation_scale_t': 1, - 'interpolation_scale_h': 1, - 'interpolation_scale_w': 1, - "sparse1d": True, - "sparse_n": 4, - "rank": 64, - } + args = type( + "args", + (), + { + "ae": "WFVAEModel_D8_4x8x8", + "model_max_length": 300, + "max_height": 256, + "max_width": 512, + "num_frames": 33, + "compress_kv_factor": 1, + "interpolation_scale_t": 1, + "interpolation_scale_h": 1, + "interpolation_scale_w": 1, + "sparse1d": True, + "sparse_n": 4, + "rank": 64, + }, ) b = 2 c = 8 @@ -544,11 +554,11 @@ def OpenSoraT2V_v1_3_2B_122(**kwargs): num_frames = (args.num_frames - 1) // ae_stride_t + 1 model = OpenSoraT2V_v1_3_2B_122( - in_channels=c, - out_channels=c, - sample_size_h=latent_size, - sample_size_w=latent_size, - sample_size_t=num_frames, + in_channels=c, + out_channels=c, + sample_size_h=latent_size, + sample_size_w=latent_size, + sample_size_t=num_frames, # activation_fn="gelu-approximate", attention_bias=True, double_self_attention=False, @@ -556,16 +566,17 @@ def OpenSoraT2V_v1_3_2B_122(**kwargs): norm_eps=1e-06, only_cross_attention=False, upcast_attention=False, - interpolation_scale_t=args.interpolation_scale_t, - interpolation_scale_h=args.interpolation_scale_h, - interpolation_scale_w=args.interpolation_scale_w, - sparse1d=args.sparse1d, - sparse_n=args.sparse_n + interpolation_scale_t=args.interpolation_scale_t, + interpolation_scale_h=args.interpolation_scale_h, + interpolation_scale_w=args.interpolation_scale_w, + sparse1d=args.sparse1d, + sparse_n=args.sparse_n, ) - + try: - path = "/home_host/susan/workspace/checkpoints/Open-Sora-Plan-v1.3.0/any93x640x640/diffusion_pytorch_model.safetensors" + path = "checkpoints/Open-Sora-Plan-v1.3.0/any93x640x640/diffusion_pytorch_model.safetensors" from safetensors.torch import load_file as safe_load + ckpt = safe_load(path, device="cpu") msg = model.load_state_dict(ckpt, strict=True) print(msg) @@ -576,9 +587,15 @@ def OpenSoraT2V_v1_3_2B_122(**kwargs): # print(model) # print(f"{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9} B") # import sys;sys.exit() - x = ops.randn(b, c, 1+(args.num_frames-1)//ae_stride_t, args.max_height//ae_stride_h, args.max_width//ae_stride_w) + x = ops.randn( + b, c, 1 + (args.num_frames - 1) // ae_stride_t, args.max_height // ae_stride_h, args.max_width // ae_stride_w + ) cond = ops.randn(b, 1, args.model_max_length, cond_c) - attn_mask = ops.randint(0, 2, (b, 1+(args.num_frames-1)//ae_stride_t, args.max_height//ae_stride_h, args.max_width//ae_stride_w)) # B L or B 1+num_images L + attn_mask = ops.randint( + 0, + 2, + (b, 1 + (args.num_frames - 1) // ae_stride_t, args.max_height // ae_stride_h, args.max_width // ae_stride_w), + ) # B L or B 1+num_images L cond_mask = ops.randint(0, 2, (b, 1, args.model_max_length)) # B L or B 1+num_images L timestep = ops.randint(0, 1000, (b,)) model_kwargs = dict( diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py index bc1f0db9cc..bb67c58360 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py @@ -19,6 +19,7 @@ logger = logging.getLogger(__name__) + class LayerNorm(nn.Cell): def __init__(self, normalized_shape, eps=1e-5, elementwise_affine: bool = True, dtype=ms.float32): super().__init__() @@ -41,24 +42,22 @@ def construct(self, x: ms.Tensor): class Attention(Attention_): - def __init__( - self, interpolation_scale_thw, sparse1d, sparse_n, - sparse_group, is_cross_attn, **kwags - ): - + def __init__(self, interpolation_scale_thw, sparse1d, sparse_n, sparse_group, is_cross_attn, **kwags): processor = OpenSoraAttnProcessor2_0( - interpolation_scale_thw=interpolation_scale_thw, sparse1d=sparse1d, sparse_n=sparse_n, - sparse_group=sparse_group, is_cross_attn=is_cross_attn, - dim_head=kwags["dim_head"] - ) + interpolation_scale_thw=interpolation_scale_thw, + sparse1d=sparse1d, + sparse_n=sparse_n, + sparse_group=sparse_group, + is_cross_attn=is_cross_attn, + dim_head=kwags["dim_head"], + ) super().__init__(processor=processor, **kwags) if npu_config.enable_FA: self.set_use_memory_efficient_attention_xformers(True) self.processor = processor - + @staticmethod def prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n, head_num): - attention_mask = attention_mask.unsqueeze(1) encoder_attention_mask = encoder_attention_mask.unsqueeze(1) l = attention_mask.shape[-1] @@ -67,33 +66,40 @@ def prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n, head_n else: pad_len = sparse_n * sparse_n - l % (sparse_n * sparse_n) - attention_mask_sparse = mint.nn.functional.pad(attention_mask, (0, pad_len, 0, 0), mode="constant", value=0) # 0 for discard + attention_mask_sparse = mint.nn.functional.pad( + attention_mask, (0, pad_len, 0, 0), mode="constant", value=0 + ) # 0 for discard b = attention_mask_sparse.shape[0] k = sparse_n m = sparse_n # b 1 1 (g k) -> (k b) 1 1 g - attention_mask_sparse_1d = attention_mask_sparse.reshape(b, 1, 1, -1, k).permute(4, 0, 1, 2, 3).reshape(b*k, 1, 1, -1) + attention_mask_sparse_1d = ( + attention_mask_sparse.reshape(b, 1, 1, -1, k).permute(4, 0, 1, 2, 3).reshape(b * k, 1, 1, -1) + ) # b 1 1 (n m k) -> (m b) 1 1 (n k) - attention_mask_sparse_1d_group = attention_mask_sparse.reshape(b, 1, 1, -1, m, k).permute(4, 0, 1, 2, 3, 5).reshape(m*b, 1, 1, -1) + attention_mask_sparse_1d_group = ( + attention_mask_sparse.reshape(b, 1, 1, -1, m, k).permute(4, 0, 1, 2, 3, 5).reshape(m * b, 1, 1, -1) + ) encoder_attention_mask_sparse = encoder_attention_mask.tile((sparse_n, 1, 1, 1)) - - # get attention mask dtype, and shape + + # get attention mask dtype, and shape attention_mask_sparse_1d = npu_config.get_attention_mask( attention_mask_sparse_1d, attention_mask_sparse_1d.shape[-1] - ) + ) attention_mask_sparse_1d_group = npu_config.get_attention_mask( attention_mask_sparse_1d_group, attention_mask_sparse_1d_group.shape[-1] - ) - + ) + encoder_attention_mask_sparse_1d = npu_config.get_attention_mask( encoder_attention_mask_sparse, attention_mask_sparse_1d.shape[-1] - ) + ) encoder_attention_mask_sparse_1d_group = encoder_attention_mask_sparse_1d - + return { - False: (attention_mask_sparse_1d, encoder_attention_mask_sparse_1d), - True: (attention_mask_sparse_1d_group, encoder_attention_mask_sparse_1d_group) - } + False: (attention_mask_sparse_1d, encoder_attention_mask_sparse_1d), + True: (attention_mask_sparse_1d_group, encoder_attention_mask_sparse_1d_group), + } + # NO USE YET def prepare_attention_mask( self, attention_mask: ms.Tensor, target_length: int, batch_size: int, out_dim: int = 3 @@ -117,13 +123,13 @@ def prepare_attention_mask( head_size = self.heads if get_sequence_parallel_state(): head_size = head_size // hccl_info.world_size # e.g, 24 // 8 - + if attention_mask is None: # b 1 t*h*w in sa, b 1 l in ca return attention_mask current_length: int = attention_mask.shape[-1] if current_length != target_length: - attention_mask = mint.nn.functional.pad(attention_mask, (0, target_length), mode="constant", value=0.0) + attention_mask = mint.nn.functional.pad(attention_mask, (0, target_length), mode="constant", value=0.0) if out_dim == 3: if attention_mask.shape[0] < batch_size * head_size: @@ -141,15 +147,21 @@ class OpenSoraAttnProcessor2_0: Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ - def __init__(self, interpolation_scale_thw=(1, 1, 1), - sparse1d=False, sparse_n=2, sparse_group=False, is_cross_attn=True, - dim_head=96): + def __init__( + self, + interpolation_scale_thw=(1, 1, 1), + sparse1d=False, + sparse_n=2, + sparse_group=False, + is_cross_attn=True, + dim_head=96, + ): self.sparse1d = sparse1d self.sparse_n = sparse_n self.sparse_group = sparse_group self.is_cross_attn = is_cross_attn self.interpolation_scale_thw = interpolation_scale_thw - + self._init_rope(interpolation_scale_thw, dim_head=dim_head) # if npu_config.enable_FA: @@ -165,7 +177,7 @@ def __init__(self, interpolation_scale_thw=(1, 1, 1), self.alltoall_sbh_q = AllToAll_SBH(scatter_dim=1, gather_dim=0) self.alltoall_sbh_k = AllToAll_SBH(scatter_dim=1, gather_dim=0) self.alltoall_sbh_v = AllToAll_SBH(scatter_dim=1, gather_dim=0) - self.alltoall_sbh_out = AllToAll_SBH(scatter_dim=1, gather_dim=0) + self.alltoall_sbh_out = AllToAll_SBH(scatter_dim=0, gather_dim=1) else: self.sp_size = 1 self.alltoall_sbh_q = None @@ -175,28 +187,28 @@ def __init__(self, interpolation_scale_thw=(1, 1, 1), def _init_rope(self, interpolation_scale_thw, dim_head): self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw, dim_head=dim_head) - self.position_getter = PositionGetter3D() + self.position_getter = PositionGetter3D() # TODO: need consider shapes for parallel seq and non-parallel cases def _sparse_1d(self, x, frame, height, width): """ require the shape of (ntokens x batch_size x dim) - + Convert to sparse groups Input: x: shape in S,B,D Output: x: shape if sparse_group: (S//sparse_n, sparse_n*B, D), else: (S//sparse_n, sparse_n*B, D) - pad_len: 0 or padding + pad_len: 0 or padding """ l = x.shape[0] - assert l == frame*height*width + assert l == frame * height * width pad_len = 0 if l % (self.sparse_n * self.sparse_n) != 0: pad_len = self.sparse_n * self.sparse_n - l % (self.sparse_n * self.sparse_n) if pad_len != 0: x = mint.nn.functional.pad(x, (0, 0, 0, 0, 0, pad_len), mode="constant", value=0.0) - + _, b, d = x.shape if not self.sparse_group: # (g k) b d -> g (k b) d @@ -204,12 +216,12 @@ def _sparse_1d(self, x, frame, height, width): x = x.reshape(-1, k, b, d).reshape(-1, k * b, d) else: # (n m k) b d -> (n k) (m b) d - m = self.sparse_n + m = self.sparse_n k = self.sparse_n - x = x.reshape(-1, m, k, b, d).permute(0, 2, 1, 3, 4).reshape(-1, m*b, d) - + x = x.reshape(-1, m, k, b, d).permute(0, 2, 1, 3, 4).reshape(-1, m * b, d) + return x, pad_len - + def _reverse_sparse_1d(self, x, frame, height, width, pad_len): """ require the shape of (ntokens x batch_size x dim) @@ -220,22 +232,22 @@ def _reverse_sparse_1d(self, x, frame, height, width, pad_len): Output: x: shape if sparse_group: (S*sparse_n, B//sparse_n, D), else: (S*sparse_n, B//sparse_n, D) """ - assert x.shape[0] == (frame*height*width+pad_len) // self.sparse_n + assert x.shape[0] == (frame * height * width + pad_len) // self.sparse_n g, _, d = x.shape if not self.sparse_group: # g (k b) d -> (g k) b d k = self.sparse_n - x = x.reshape(g, k, -1, d).reshape(g*k, -1, d) + x = x.reshape(g, k, -1, d).reshape(g * k, -1, d) else: # (n k) (m b) d -> (n m k) b d m = self.sparse_n k = self.sparse_n assert g % k == 0 n = g // k - x = x.reshape(n, k, m, -1, d).permute(0, 2, 1, 3, 4).reshape(n*m*k, -1, d) - x = x[:frame*height*width, :, :] + x = x.reshape(n, k, m, -1, d).permute(0, 2, 1, 3, 4).reshape(n * m * k, -1, d) + x = x[: frame * height * width, :, :] return x - + def _sparse_1d_kv(self, x): """ require the shape of (ntokens x batch_size x dim) @@ -245,22 +257,21 @@ def _sparse_1d_kv(self, x): # x = x.repeat(self.sparse_n, axis = 1) # WRONG!!! x = x.tile((1, self.sparse_n, 1)) return x - + def __call__( self, attn: Attention, - hidden_states: ms.Tensor, - encoder_hidden_states: Optional[ms.Tensor] = None, + hidden_states: ms.Tensor, + encoder_hidden_states: Optional[ms.Tensor] = None, attention_mask: Optional[ms.Tensor] = None, temb: Optional[ms.Tensor] = None, - frame: int = 8, - height: int = 16, - width: int = 16, + frame: int = 8, + height: int = 16, + width: int = 16, *args, **kwargs, ) -> ms.Tensor: - - residual = hidden_states + residual = hidden_states if get_sequence_parallel_state(): sequence_length, batch_size, _ = ( @@ -269,7 +280,7 @@ def __call__( else: batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) #BSH + ) # BSH # print(f"hidden_states.shape {hidden_states.shape}") #BSH query = attn.to_q(hidden_states) @@ -285,7 +296,7 @@ def __call__( FA_head_num = attn.heads total_frame = frame - if get_sequence_parallel_state(): #TODO: to test + if get_sequence_parallel_state(): # TODO: to test sp_size = hccl_info.world_size FA_head_num = attn.heads // sp_size total_frame = frame * sp_size @@ -293,28 +304,28 @@ def __call__( query = self.alltoall_sbh_q(query.view(-1, attn.heads, head_dim)) key = self.alltoall_sbh_k(key.view(-1, attn.heads, head_dim)) value = self.alltoall_sbh_v(value.view(-1, attn.heads, head_dim)) - + # print(f'batch: {batch_size}, FA_head_num: {FA_head_num}, head_dim: {head_dim}, total_frame:{total_frame}') - query = query.view(-1, batch_size, FA_head_num, head_dim)# BUG? TODO: to test - key = key.view(-1, batch_size, FA_head_num, head_dim) #BUG ? + query = query.view(-1, batch_size, FA_head_num, head_dim) # BUG? TODO: to test + key = key.view(-1, batch_size, FA_head_num, head_dim) # BUG ? # print(f'q {query.shape}, k {key.shape}, v {value.shape}') if not self.is_cross_attn: - # require the shape of (ntokens x batch_size x nheads x dim) + # require the shape of (ntokens x batch_size x nheads x dim) pos_thw = self.position_getter(batch_size, t=total_frame, h=height, w=width) # print(f'pos_thw {pos_thw}') query = self.rope(query, pos_thw) key = self.rope(key, pos_thw) - + query = query.view(-1, batch_size, FA_head_num * head_dim) key = key.view(-1, batch_size, FA_head_num * head_dim) value = value.view(-1, batch_size, FA_head_num * head_dim) else: # print(f'batch: {batch_size}, FA_head_num: {FA_head_num}, head_dim: {head_dim}, total_frame:{total_frame}') query = query.view(batch_size, -1, FA_head_num, head_dim) - key = key.view(batch_size, -1, FA_head_num, head_dim) + key = key.view(batch_size, -1, FA_head_num, head_dim) # (batch_size x ntokens x nheads x dim) - + # print(f'q {query.shape}, k {key.shape}, v {value.shape}') if not self.is_cross_attn: # require the shape of (batch_size x ntokens x nheads x dim) @@ -322,33 +333,34 @@ def __call__( # print(f'pos_thw {pos_thw}') query = self.rope(query, pos_thw) key = self.rope(key, pos_thw) - + query = query.view(batch_size, -1, FA_head_num * head_dim).swapaxes(0, 1) key = key.view(batch_size, -1, FA_head_num * head_dim).swapaxes(0, 1) value = value.swapaxes(0, 1) - - # print(f'q {query.shape}, k {key.shape}, v {value.shape}') #(SBH) + + # print(f'q {query.shape}, k {key.shape}, v {value.shape}') #(SBH) if self.sparse1d: query, pad_len = self._sparse_1d(query, total_frame, height, width) if self.is_cross_attn: - key = self._sparse_1d_kv(key) + key = self._sparse_1d_kv(key) value = self._sparse_1d_kv(value) else: key, pad_len = self._sparse_1d(key, total_frame, height, width) value, pad_len = self._sparse_1d(value, total_frame, height, width) - # print(f'q {query.shape}, k {key.shape}, v {value.shape}') query = query.swapaxes(0, 1) # SBH to BSH key = key.swapaxes(0, 1) value = value.swapaxes(0, 1) - hidden_states = npu_config.run_attention(query, key, value, attention_mask, input_layout="BSH", head_dim=head_dim, head_num=FA_head_num) + hidden_states = npu_config.run_attention( + query, key, value, attention_mask, input_layout="BSH", head_dim=head_dim, head_num=FA_head_num + ) if self.sparse1d: - hidden_states = hidden_states.swapaxes(0, 1) # BSH -> SBH + hidden_states = hidden_states.swapaxes(0, 1) # BSH -> SBH hidden_states = self._reverse_sparse_1d(hidden_states, total_frame, height, width, pad_len) - hidden_states = hidden_states.swapaxes(0, 1) # SBH -> BSH + hidden_states = hidden_states.swapaxes(0, 1) # SBH -> BSH # [s, b, h // sp * d] -> [s // sp * b, h, d] -> [s // sp, b, h * d] if get_sequence_parallel_state(): @@ -366,6 +378,7 @@ def __call__( return hidden_states + class BasicTransformerBlock(nn.Cell): def __init__( self, @@ -385,7 +398,7 @@ def __init__( ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, - interpolation_scale_thw: Tuple[int] = (1, 1, 1), + interpolation_scale_thw: Tuple[int] = (1, 1, 1), sparse1d: bool = False, sparse_n: int = 2, sparse_group: bool = False, @@ -413,7 +426,7 @@ def __init__( cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, out_bias=attention_out_bias, - interpolation_scale_thw=interpolation_scale_thw, + interpolation_scale_thw=interpolation_scale_thw, sparse1d=sparse1d, sparse_n=sparse_n, sparse_group=sparse_group, @@ -432,12 +445,12 @@ def __init__( bias=attention_bias, upcast_attention=upcast_attention, out_bias=attention_out_bias, - interpolation_scale_thw=interpolation_scale_thw, + interpolation_scale_thw=interpolation_scale_thw, sparse1d=sparse1d, sparse_n=sparse_n, sparse_group=sparse_group, is_cross_attn=True, - ) + ) # 3. Feed-forward self.ff = FeedForward( @@ -452,7 +465,6 @@ def __init__( # 4. Scale-shift. self.scale_shift_table = Parameter(ops.randn((6, dim)) / dim**0.5) - def construct( self, hidden_states: ms.Tensor, @@ -460,21 +472,21 @@ def construct( encoder_hidden_states: Optional[ms.Tensor] = None, encoder_attention_mask: Optional[ms.Tensor] = None, timestep: Optional[ms.Tensor] = None, - frame: int = None, - height: int = None, - width: int = None, + frame: int = None, + height: int = None, + width: int = None, ) -> ms.Tensor: - # 0. Self-Attention if get_sequence_parallel_state(): batch_size = hidden_states.shape[1] shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk( - self.scale_shift_table[:, None] + timestep.reshape(6, batch_size, -1), 6, dim=0) + self.scale_shift_table[:, None] + timestep.reshape(6, batch_size, -1), 6, dim=0 + ) else: batch_size = hidden_states.shape[0] shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk( self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1), 6, dim=1 - ) + ) norm_hidden_states = self.norm1(hidden_states) @@ -483,7 +495,10 @@ def construct( attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=None, - attention_mask=attention_mask, frame=frame, height=height, width=width, + attention_mask=attention_mask, + frame=frame, + height=height, + width=width, ) attn_output = gate_msa * attn_output @@ -498,7 +513,10 @@ def construct( attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, frame=frame, height=height, width=width, + attention_mask=encoder_attention_mask, + frame=frame, + height=height, + width=width, ) hidden_states = attn_output + hidden_states @@ -514,4 +532,4 @@ def construct( hidden_states = ff_output + hidden_states - return hidden_states \ No newline at end of file + return hidden_states From 53336e10b019e27663deba5efcbd6964c976bda4 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 19 Nov 2024 16:57:26 +0800 Subject: [PATCH 013/133] update training scripts --- .../multi-devices/train_image3d_480p_zero2.sh | 47 -------------- ...3d_nx480p_zero2.sh => train_t2v_stage1.sh} | 47 ++++++++------ ...3x720p_zero2_sp.sh => train_t2v_stage2.sh} | 58 ++++++++++------- .../multi-devices/train_t2v_stage3.sh | 65 +++++++++++++++++++ .../train_video3d_29x720p_zero2_sp.sh | 57 ---------------- .../train_video3d_29x720p_zero2_sp_val.sh | 61 ----------------- 6 files changed, 125 insertions(+), 210 deletions(-) delete mode 100644 examples/opensora_pku/scripts/text_condition/multi-devices/train_image3d_480p_zero2.sh rename examples/opensora_pku/scripts/text_condition/multi-devices/{train_video3d_nx480p_zero2.sh => train_t2v_stage1.sh} (55%) rename examples/opensora_pku/scripts/text_condition/multi-devices/{train_video3d_93x720p_zero2_sp.sh => train_t2v_stage2.sh} (50%) create mode 100644 examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh delete mode 100644 examples/opensora_pku/scripts/text_condition/multi-devices/train_video3d_29x720p_zero2_sp.sh delete mode 100644 examples/opensora_pku/scripts/text_condition/multi-devices/train_video3d_29x720p_zero2_sp_val.sh diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_image3d_480p_zero2.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_image3d_480p_zero2.sh deleted file mode 100644 index 360a0454e4..0000000000 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_image3d_480p_zero2.sh +++ /dev/null @@ -1,47 +0,0 @@ -# Stage 2: 1x480p -msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="t2i-image3d-1x480p_zero2/parallel_logs" \ - opensora/train/train_t2v_diffusers.py \ - --model OpenSoraT2V-ROPE-L/122 \ - --text_encoder_name google/mt5-xxl \ - --cache_dir "./" \ - --dataset t2v \ - --data "scripts/train_data/merge_data.txt" \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path "LanguageBind/Open-Sora-Plan-v1.2.0/vae" \ - --sample_rate 1 \ - --num_frames 1 \ - --max_height 480 \ - --max_width 640 \ - --interpolation_scale_t 1.0 \ - --interpolation_scale_h 1.0 \ - --interpolation_scale_w 1.0 \ - --attention_mode xformers \ - --gradient_checkpointing \ - --train_batch_size=8 \ - --dataloader_num_workers 20 \ - --gradient_accumulation_steps=1 \ - --max_train_steps=1000000 \ - --start_learning_rate=1e-4 \ - --lr_scheduler="constant" \ - --seed=10 \ - --lr_warmup_steps=500 \ - --precision="bf16" \ - --checkpointing_steps=2000 \ - --output_dir="t2i-image3d-1x480p_zero2/" \ - --model_max_length 512 \ - --use_image_num 0 \ - --snr_gamma 5.0 \ - --use_ema True\ - --ema_start_step 0 \ - --enable_tiling \ - --tile_overlap_factor 0.0 \ - --pretrained "path/to/pretrained/1x240p/ckpt" \ - --clip_grad True \ - --max_grad_norm 1.0 \ - --use_rope \ - --noise_offset 0.02 \ - --use_parallel True \ - --parallel_mode "zero" \ - --zero_stage 2 \ - --max_device_memory "59GB" \ - --jit_syntax_level "lax" \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_video3d_nx480p_zero2.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage1.sh similarity index 55% rename from examples/opensora_pku/scripts/text_condition/multi-devices/train_video3d_nx480p_zero2.sh rename to examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage1.sh index ff1af4c51c..e46ce5f429 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_video3d_nx480p_zero2.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage1.sh @@ -1,56 +1,63 @@ -# Stage 3: 29x480p -NUM_FRAME=29 -msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="t2v-video3d-${NUM_FRAME}x480p_zero2/parallel_logs" \ +# Stage 1: 1x320x320 +NUM_FRAME=1 +WIDTH=320 +HEIGHT=320 +ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}/parallel_logs" \ opensora/train/train_t2v_diffusers.py \ - --model OpenSoraT2V-ROPE-L/122 \ - --text_encoder_name google/mt5-xxl \ + --model OpenSoraT2V_v1_3-2B/122 \ + --text_encoder_name_1 google/mt5-xxl \ --cache_dir "./" \ --dataset t2v \ - --data "scripts/train_data/merge_data_mixkit.txt" \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path "LanguageBind/Open-Sora-Plan-v1.2.0/vae" \ + --data "scripts/train_data/image_data_v1_2.txt" \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ --sample_rate 1 \ --num_frames ${NUM_FRAME} \ - --max_height 480 \ - --max_width 640 \ + --max_height ${HEIGHT} \ + --max_width ${WIDTH} \ --interpolation_scale_t 1.0 \ --interpolation_scale_h 1.0 \ --interpolation_scale_w 1.0 \ - --attention_mode xformers \ --gradient_checkpointing \ --train_batch_size=1 \ --dataloader_num_workers 8 \ --gradient_accumulation_steps=1 \ --max_train_steps=1000000 \ - --start_learning_rate=1e-4 \ + --start_learning_rate=2e-5 \ --lr_scheduler="constant" \ --seed=10 \ - --lr_warmup_steps=500 \ + --lr_warmup_steps=0 \ --precision="bf16" \ --checkpointing_steps=1000 \ - --output_dir="t2v-video3d-${NUM_FRAME}x480p_zero2/" \ + --output_dir="t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}/" \ --model_max_length 512 \ --use_image_num 0 \ --cfg 0.1 \ --snr_gamma 5.0 \ + --rescale_betas_zero_snr \ --use_ema True\ --ema_start_step 0 \ --enable_tiling \ --tile_overlap_factor 0.125 \ --clip_grad True \ --max_grad_norm 1.0 \ - --use_rope \ --noise_offset 0.02 \ --enable_stable_fp32 True\ --ema_decay 0.999 \ --speed_factor 1.0 \ - --drop_short_ratio 1.0 \ - --pretrained "LanguageBind/Open-Sora-Plan-v1.2.0/1x480p" \ + --drop_short_ratio 0.0 \ --use_parallel True \ --parallel_mode "zero" \ --zero_stage 2 \ --max_device_memory "59GB" \ --jit_syntax_level "lax" \ - --dataset_sink_mode True \ - --num_no_recompute 18 \ - # --group_frame \ + --dataset_sink_mode False \ + --prediction_type "v_prediction" \ + --hw_stride 32 \ + --sparse1d \ + --sparse_n 4 \ + --train_fps 16 \ + --trained_data_global_step 0 \ + --group_data \ + --mode 1 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_video3d_93x720p_zero2_sp.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh similarity index 50% rename from examples/opensora_pku/scripts/text_condition/multi-devices/train_video3d_93x720p_zero2_sp.sh rename to examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh index 1aa63f71cb..7b3bc691bf 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_video3d_93x720p_zero2_sp.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh @@ -1,57 +1,65 @@ -# Stage 5: 93x720p +# Stage 2: 93x320x320 NUM_FRAME=93 -msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="t2v-video3d-${NUM_FRAME}x720p_zero2_sp/parallel_logs" \ +WIDTH=320 +HEIGHT=320 +ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}/parallel_logs" \ opensora/train/train_t2v_diffusers.py \ - --model OpenSoraT2V-ROPE-L/122 \ - --text_encoder_name google/mt5-xxl \ + --model OpenSoraT2V_v1_3-2B/122 \ + --text_encoder_name_1 google/mt5-xxl \ --cache_dir "./" \ --dataset t2v \ - --data "scripts/train_data/merge_data_mixkit.txt" \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path "LanguageBind/Open-Sora-Plan-v1.2.0/vae" \ + --data "scripts/train_data/video_data_v1_2.txt" \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ --sample_rate 1 \ --num_frames ${NUM_FRAME} \ - --max_height 720 \ - --max_width 1280 \ + --max_height ${HEIGHT} \ + --max_width ${WIDTH} \ --interpolation_scale_t 1.0 \ - --interpolation_scale_h 1.5 \ - --interpolation_scale_w 2.0 \ - --attention_mode xformers \ + --interpolation_scale_h 1.0 \ + --interpolation_scale_w 1.0 \ + --gradient_checkpointing \ --train_batch_size=1 \ --dataloader_num_workers 8 \ --gradient_accumulation_steps=1 \ --max_train_steps=1000000 \ - --start_learning_rate=1e-4 \ + --start_learning_rate=2e-5 \ --lr_scheduler="constant" \ --seed=10 \ - --lr_warmup_steps=500 \ + --lr_warmup_steps=0 \ --precision="bf16" \ --checkpointing_steps=1000 \ - --output_dir="t2v-video3d-${NUM_FRAME}x720p_zero2_sp/" \ + --output_dir="t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}/" \ --model_max_length 512 \ --use_image_num 0 \ --cfg 0.1 \ --snr_gamma 5.0 \ + --rescale_betas_zero_snr \ --use_ema True\ --ema_start_step 0 \ - --clip_grad True \ - --max_grad_norm 1.0 \ --enable_tiling \ --tile_overlap_factor 0.125 \ - --use_rope \ + --clip_grad True \ + --max_grad_norm 1.0 \ --noise_offset 0.02 \ --enable_stable_fp32 True\ --ema_decay 0.999 \ --speed_factor 1.0 \ - --drop_short_ratio 1.0 \ - --pretrained "LanguageBind/Open-Sora-Plan-v1.2.0/29x720p" \ + --drop_short_ratio 0.0 \ --use_parallel True \ --parallel_mode "zero" \ --zero_stage 2 \ - --sp_size 8 \ - --train_sp_batch_size 1 \ --max_device_memory "59GB" \ --jit_syntax_level "lax" \ - --dataset_sink_mode True \ - --gradient_checkpointing \ - # --group_frame \ + --dataset_sink_mode False \ + --prediction_type "v_prediction" \ + --hw_stride 32 \ + --sparse1d \ + --sparse_n 4 \ + --train_fps 16 \ + --trained_data_global_step 0 \ + --group_data \ + --mode 1 \ + # --sp_size 8 \ + # --train_sp_batch_size 1 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh new file mode 100644 index 0000000000..02d71499f9 --- /dev/null +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh @@ -0,0 +1,65 @@ +# Stage 3: 93x480x480 (480x480, 640x352, 352x640) +NUM_FRAME=93 +WIDTH=480 +HEIGHT=480 +ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}/parallel_logs" \ + opensora/train/train_t2v_diffusers.py \ + --model OpenSoraT2V_v1_3-2B/122 \ + --text_encoder_name_1 google/mt5-xxl \ + --cache_dir "./" \ + --dataset t2v \ + --data "scripts/train_data/video_data_v1_2.txt" \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ + --sample_rate 1 \ + --num_frames ${NUM_FRAME} \ + --max_height ${HEIGHT} \ + --max_width ${WIDTH} \ + --interpolation_scale_t 1.0 \ + --interpolation_scale_h 1.0 \ + --interpolation_scale_w 1.0 \ + --gradient_checkpointing \ + --train_batch_size=1 \ + --dataloader_num_workers 8 \ + --gradient_accumulation_steps=1 \ + --max_train_steps=1000000 \ + --start_learning_rate=1e-5 \ + --lr_scheduler="constant" \ + --seed=10 \ + --lr_warmup_steps=0 \ + --precision="bf16" \ + --checkpointing_steps=1000 \ + --output_dir="t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}/" \ + --model_max_length 512 \ + --use_image_num 0 \ + --cfg 0.1 \ + --snr_gamma 5.0 \ + --rescale_betas_zero_snr \ + --use_ema True\ + --ema_start_step 0 \ + --enable_tiling \ + --tile_overlap_factor 0.125 \ + --clip_grad True \ + --max_grad_norm 1.0 \ + --noise_offset 0.02 \ + --enable_stable_fp32 True\ + --ema_decay 0.999 \ + --speed_factor 1.0 \ + --drop_short_ratio 0.0 \ + --use_parallel True \ + --parallel_mode "zero" \ + --zero_stage 2 \ + --max_device_memory "59GB" \ + --jit_syntax_level "lax" \ + --dataset_sink_mode False \ + --prediction_type "v_prediction" \ + --hw_stride 32 \ + --sparse1d \ + --sparse_n 4 \ + --train_fps 16 \ + --trained_data_global_step 0 \ + --group_data \ + --mode 1 \ + # --sp_size 8 \ + # --train_sp_batch_size 1 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_video3d_29x720p_zero2_sp.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_video3d_29x720p_zero2_sp.sh deleted file mode 100644 index 5e0fe11fb1..0000000000 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_video3d_29x720p_zero2_sp.sh +++ /dev/null @@ -1,57 +0,0 @@ -# Stage 4: 29x720p -NUM_FRAME=29 -msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="t2v-video3d-${NUM_FRAME}x720p_zero2_sp/parallel_logs" \ - opensora/train/train_t2v_diffusers.py \ - --model OpenSoraT2V-ROPE-L/122 \ - --text_encoder_name google/mt5-xxl \ - --cache_dir "./" \ - --dataset t2v \ - --data "scripts/train_data/merge_data_mixkit.txt" \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path "LanguageBind/Open-Sora-Plan-v1.2.0/vae" \ - --sample_rate 1 \ - --num_frames ${NUM_FRAME} \ - --max_height 720 \ - --max_width 1280 \ - --interpolation_scale_t 1.0 \ - --interpolation_scale_h 1.5 \ - --interpolation_scale_w 2.0 \ - --attention_mode xformers \ - --train_batch_size=1 \ - --dataloader_num_workers 8 \ - --gradient_accumulation_steps=1 \ - --max_train_steps=1000000 \ - --start_learning_rate=1e-4 \ - --lr_scheduler="constant" \ - --seed=10 \ - --lr_warmup_steps=500 \ - --precision="bf16" \ - --checkpointing_steps=1000 \ - --output_dir="t2v-video3d-${NUM_FRAME}x720p_zero2_sp/" \ - --model_max_length 512 \ - --use_image_num 0 \ - --cfg 0.1 \ - --snr_gamma 5.0 \ - --use_ema True\ - --ema_start_step 0 \ - --clip_grad True \ - --max_grad_norm 1.0 \ - --enable_tiling \ - --tile_overlap_factor 0.125 \ - --use_rope \ - --noise_offset 0.02 \ - --enable_stable_fp32 True\ - --ema_decay 0.999 \ - --speed_factor 1.0 \ - --drop_short_ratio 1.0 \ - --pretrained "LanguageBind/Open-Sora-Plan-v1.2.0/29x480p" \ - --use_parallel True \ - --parallel_mode "zero" \ - --zero_stage 2 \ - --sp_size 8 \ - --train_sp_batch_size 1 \ - --max_device_memory "59GB" \ - --jit_syntax_level "lax" \ - --dataset_sink_mode True \ - # --gradient_checkpointing \ - # --group_frame \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_video3d_29x720p_zero2_sp_val.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_video3d_29x720p_zero2_sp_val.sh deleted file mode 100644 index cc600f2609..0000000000 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_video3d_29x720p_zero2_sp_val.sh +++ /dev/null @@ -1,61 +0,0 @@ -# Stage 4: 29x720p training with validation -NUM_FRAME=29 -msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="t2v-video3d-${NUM_FRAME}x720p_zero2_sp/parallel_logs" \ - opensora/train/train_t2v_diffusers.py \ - --model OpenSoraT2V-ROPE-L/122 \ - --text_encoder_name google/mt5-xxl \ - --cache_dir "./" \ - --dataset t2v \ - --data "scripts/train_data/merge_data_train.txt" \ - --val_data "scripts/train_data/merge_data_val.txt" \ - --validate True \ - --val_batch_size 1 \ - --val_interval 1 \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path "LanguageBind/Open-Sora-Plan-v1.2.0/vae" \ - --sample_rate 1 \ - --num_frames ${NUM_FRAME} \ - --max_height 720 \ - --max_width 1280 \ - --interpolation_scale_t 1.0 \ - --interpolation_scale_h 1.5 \ - --interpolation_scale_w 2.0 \ - --attention_mode xformers \ - --train_batch_size=1 \ - --dataloader_num_workers 8 \ - --gradient_accumulation_steps=1 \ - --max_train_steps=1000000 \ - --start_learning_rate=1e-4 \ - --lr_scheduler="constant" \ - --seed=10 \ - --lr_warmup_steps=500 \ - --precision="bf16" \ - --checkpointing_steps=1000 \ - --output_dir="t2v-video3d-${NUM_FRAME}x720p_zero2_sp/" \ - --model_max_length 512 \ - --use_image_num 0 \ - --cfg 0.1 \ - --snr_gamma 5.0 \ - --use_ema True\ - --ema_start_step 0 \ - --clip_grad True \ - --max_grad_norm 1.0 \ - --enable_tiling \ - --tile_overlap_factor 0.125 \ - --use_rope \ - --noise_offset 0.02 \ - --enable_stable_fp32 True\ - --ema_decay 0.999 \ - --speed_factor 1.0 \ - --drop_short_ratio 1.0 \ - --pretrained "LanguageBind/Open-Sora-Plan-v1.2.0/29x480p" \ - --use_parallel True \ - --parallel_mode "zero" \ - --zero_stage 2 \ - --sp_size 8 \ - --train_sp_batch_size 1 \ - --max_device_memory "59GB" \ - --jit_syntax_level "lax" \ - --dataset_sink_mode True \ - # --gradient_checkpointing \ - # --group_frame \ From 870e2b2856dbc2d6c46b7736e5fb5fbd6ac7b667 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 21 Nov 2024 10:35:21 +0800 Subject: [PATCH 014/133] fix syntax error for graph mode --- .../models/diffusion/opensora/modeling_opensora.py | 12 +++++++----- .../opensora/models/diffusion/opensora/modules.py | 2 -- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index b91a089da7..d4d2d60938 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -134,6 +134,10 @@ def _init_patched_inputs(self): pad_mode="pad", ) + # save as attributes used in construct + self.patch_size_t = self.config.patch_size_t + self.patch_size = self.config.patch_size + def recompute(self, b): if not b._has_config_recompute: b.recompute(parallel_optimizer_comm_recompute=True) @@ -379,12 +383,10 @@ def construct( encoder_attention_mask = self.get_attention_mask(encoder_attention_mask) # if use bool mask # 1. Input - frame = ( - ((frame - 1) // self.config.patch_size_t + 1) if frame % 2 == 1 else frame // self.config.patch_size_t - ) # patchfy + frame = ((frame - 1) // self.patch_size_t + 1) if frame % 2 == 1 else frame // self.patch_size_t # patchfy height, width = ( - hidden_states.shape[-2] // self.config.patch_size, - hidden_states.shape[-1] // self.config.patch_size, + hidden_states.shape[-2] // self.patch_size, + hidden_states.shape[-1] // self.patch_size, ) hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py index bb67c58360..8a0aaa8de1 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py @@ -268,8 +268,6 @@ def __call__( frame: int = 8, height: int = 16, width: int = 16, - *args, - **kwargs, ) -> ms.Tensor: residual = hidden_states From 36b561a74db93e7c11e6f41cb4b780fd4ca96b08 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 21 Nov 2024 10:49:24 +0800 Subject: [PATCH 015/133] update import path & fix ci error --- .../opensora/sample/pipeline_opensora.py | 133 ++++---- .../opensora/utils/sample_utils.py | 318 +++++++++--------- 2 files changed, 239 insertions(+), 212 deletions(-) diff --git a/examples/opensora_pku/opensora/sample/pipeline_opensora.py b/examples/opensora_pku/opensora/sample/pipeline_opensora.py index 59a408b533..a642808dca 100644 --- a/examples/opensora_pku/opensora/sample/pipeline_opensora.py +++ b/examples/opensora_pku/opensora/sample/pipeline_opensora.py @@ -1,46 +1,38 @@ -import html import inspect import logging -import math -import re -import urllib.parse as ul -from typing import Callable, List, Optional, Tuple, Union, Dict +from typing import Callable, List, Optional, Union from opensora.acceleration.communications import AllGather from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info +from transformers import CLIPTokenizer, MT5Tokenizer import mindspore as ms from mindspore import mint, ops -from mindone.diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from mindone.diffusers.utils import BACKENDS_MAPPING, deprecate, is_bs4_available, is_ftfy_available, BaseOutput -from mindone.diffusers import AutoencoderKL -from mindone.diffusers import DDPMScheduler, FlowMatchEulerDiscreteScheduler +from mindone.diffusers import AutoencoderKL, DDPMScheduler, FlowMatchEulerDiscreteScheduler +from mindone.diffusers.pipelines.pipeline_utils import DiffusionPipeline +from mindone.diffusers.utils import BaseOutput from mindone.diffusers.utils.mindspore_utils import randn_tensor +from mindone.transformers import CLIPTextModelWithProjection, T5EncoderModel + # from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback #TODO:TBD -from mindone.transformers import CLIPTextModelWithProjection, T5EncoderModel -from transformers import CLIPTokenizer, CLIPImageProcessor, MT5Tokenizer logger = logging.getLogger(__name__) -if is_bs4_available(): - from bs4 import BeautifulSoup -if is_ftfy_available(): - import ftfy - from dataclasses import dataclass + import numpy as np import PIL - -from examples.opensora_pku.opensora.models.diffusion.opensora.modeling_opensora import OpenSoraT2V_v1_3 +from opensora.models.diffusion.opensora.modeling_opensora import OpenSoraT2V_v1_3 @dataclass class OpenSoraPipelineOutput(BaseOutput): videos: Union[List[ms.Tensor], List[PIL.Image.Image], np.ndarray] + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ @@ -55,6 +47,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg return noise_cfg + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -113,7 +106,6 @@ def retrieve_timesteps( class OpenSoraPipeline(DiffusionPipeline): - model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" _optional_components = [ "text_encoder_2", @@ -152,7 +144,6 @@ def __init__( ) self.all_gather = None if not get_sequence_parallel_state() else AllGather() - @ms.jit # FIXME: on ms2.3, in pynative mode, text encoder's output has nan problem. def text_encoding_func(self, text_encoder, input_ids, attention_mask): return ops.stop_gradient(text_encoder(input_ids, attention_mask=attention_mask)) @@ -160,7 +151,7 @@ def text_encoding_func(self, text_encoder, input_ids, attention_mask): def encode_prompt( self, prompt: str, - dtype = None, + dtype=None, num_samples_per_prompt: int = 1, do_classifier_free_guidance: bool = True, negative_prompt: Optional[str] = None, @@ -244,9 +235,9 @@ def encode_prompt( text_input_ids = ms.Tensor(text_inputs.input_ids) untruncated_ids = ms.Tensor(tokenizer(prompt, padding="longest", return_tensors=None).input_ids) - if ( - untruncated_ids.shape[-1] > text_input_ids.shape[-1] or - (untruncated_ids.shape[-1] == text_input_ids.shape[-1] and not ops.equal(text_input_ids, untruncated_ids).all()) + if untruncated_ids.shape[-1] > text_input_ids.shape[-1] or ( + untruncated_ids.shape[-1] == text_input_ids.shape[-1] + and not ops.equal(text_input_ids, untruncated_ids).all() ): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( @@ -267,7 +258,6 @@ def encode_prompt( prompt_embeds = prompt_embeds.to(dtype=dtype) - bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(num_samples_per_prompt, axis=1) @@ -305,8 +295,14 @@ def encode_prompt( uncond_text_inputs = ms.Tensor(uncond_input.input_ids) negative_prompt_attention_mask = ms.Tensor(uncond_input.attention_mask) - negative_prompt_embeds = self.text_encoding_func(text_encoder, uncond_text_inputs, attention_mask=negative_prompt_attention_mask) - negative_prompt_embeds = negative_prompt_embeds[0] if isinstance(negative_prompt_embeds, (list, tuple)) else negative_prompt_embeds + negative_prompt_embeds = self.text_encoding_func( + text_encoder, uncond_text_inputs, attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = ( + negative_prompt_embeds[0] + if isinstance(negative_prompt_embeds, (list, tuple)) + else negative_prompt_embeds + ) if text_encoder_index == 1: negative_prompt_embeds = negative_prompt_embeds.unsqueeze(1) # b d -> b 1 d for clip @@ -318,7 +314,7 @@ def encode_prompt( negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype) - negative_prompt_embeds = negative_prompt_embeds.repeat(num_samples_per_prompt, axis = 1) + negative_prompt_embeds = negative_prompt_embeds.repeat(num_samples_per_prompt, axis=1) negative_prompt_embeds = negative_prompt_embeds.view((batch_size * num_samples_per_prompt, seq_len, -1)) else: negative_prompt_embeds = None @@ -370,7 +366,8 @@ def check_inputs( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, " + + f"but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: @@ -424,11 +421,13 @@ def check_inputs( ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, generator, latents=None): + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, generator, latents=None + ): shape = ( batch_size, num_channels_latents, - (int(num_frames) - 1) // self.vae.vae_scale_factor[0] + 1, + (int(num_frames) - 1) // self.vae.vae_scale_factor[0] + 1, int(height) // self.vae.vae_scale_factor[1], int(width) // self.vae.vae_scale_factor[2], ) @@ -456,7 +455,9 @@ def prepare_parallel_latent(self, video_states): if padding_needed > 0: logger.debug("Doing video padding") # B, C, T, H, W -> B, C, T', H, W - video_states = mint.nn.functional.pad(video_states, (0, 0, 0, 0, 0, padding_needed), mode="constant", value=0) + video_states = mint.nn.functional.pad( + video_states, (0, 0, 0, 0, 0, padding_needed), mode="constant", value=0 + ) b, _, f, h, w = video_states.shape temp_attention_mask = mint.ones((b, f), ms.int32) @@ -465,7 +466,7 @@ def prepare_parallel_latent(self, video_states): assert video_states.shape[2] % sp_size == 0 video_states = ops.chunk(video_states, sp_size, 2)[index] return video_states, temp_attention_mask - + @property def guidance_scale(self): return self._guidance_scale @@ -510,28 +511,28 @@ def __call__( negative_prompt_attention_mask_2: Optional[ms.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback_on_step_end: Optional[Callable[[int, int, ms.Tensor], None]] = None, # Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]] + callback_on_step_end: Optional[ + Callable[[int, int, ms.Tensor], None] + ] = None, # Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]] callback_on_step_end_tensor_inputs: List[str] = ["latents"], guidance_rescale: float = 0.0, max_sequence_length: int = 512, ): - - # TODO - if hasattr(callback_on_step_end, 'tensor_inputs'): + # TODO + if hasattr(callback_on_step_end, "tensor_inputs"): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): - # callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 0. default height and width num_frames = num_frames or (self.transformer.config.sample_size_t - 1) * self.vae.vae_scale_factor[0] + 1 height = height or self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1] width = width or self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2] - # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, - num_frames, + num_frames, height, width, negative_prompt, @@ -557,7 +558,6 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - # 3. Encode input prompt ( @@ -578,7 +578,7 @@ def __call__( max_sequence_length=max_sequence_length, text_encoder_index=0, ) - + if self.tokenizer_2 is not None: ( prompt_embeds_2, @@ -622,7 +622,7 @@ def __call__( latents = self.prepare_latents( batch_size * num_samples_per_prompt, num_channels_latents, - (num_frames + world_size - 1) // world_size if get_sequence_parallel_state() else num_frames, + (num_frames + world_size - 1) // world_size if get_sequence_parallel_state() else num_frames, height, width, prompt_embeds.dtype, @@ -675,7 +675,7 @@ def __call__( for i, t in enumerate(timesteps): if self.interrupt: continue - + # expand the latents if we are doing classifier free guidance latent_model_input = ops.cat([latents] * 2) if self.do_classifier_free_guidance else latents if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): @@ -697,8 +697,8 @@ def __call__( prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l if prompt_embeds_2 is not None and prompt_embeds_2.ndim == 2: prompt_embeds = prompt_embeds.unsqueeze(1) # b d -> b 1 d #OFFICIAL VER. DONT KNOW WHY - # prompt_embeds_2 = prompt_embeds_2.unsqueeze(1) # - + # prompt_embeds_2 = prompt_embeds_2.unsqueeze(1) # + attention_mask = ops.ones_like(latent_model_input)[:, 0] if temp_attention_mask is not None: # temp_attention_mask shape (bs, t), 1 means to keep, 0 means to discard @@ -710,28 +710,32 @@ def __call__( # ==================make sp===================================== if get_sequence_parallel_state(): - attention_mask = attention_mask.repeat(world_size, axis = 1) + attention_mask = attention_mask.repeat(world_size, axis=1) # ==================make sp===================================== noise_pred = ops.stop_gradient( - self.transformer( - latent_model_input, # (b c t h w) - attention_mask=attention_mask, + self.transformer( + latent_model_input, # (b c t h w) + attention_mask=attention_mask, encoder_hidden_states=prompt_embeds, encoder_attention_mask=prompt_attention_mask, timestep=current_timestep, - pooled_projections=prompt_embeds_2, # UNUSED!!!! + pooled_projections=prompt_embeds_2, # UNUSED!!!! return_dict=False, ) - ) # b,c,t,h,w - assert not ops.any(ops.isnan(noise_pred.float())) + ) # b,c,t,h,w + assert not ops.any(ops.isnan(noise_pred.float())) # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - if self.do_classifier_free_guidance and guidance_rescale > 0.0 and not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + if ( + self.do_classifier_free_guidance + and guidance_rescale > 0.0 + and not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler) + ): # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) @@ -761,7 +765,7 @@ def __call__( # full_shape = [latents_shape[0] * world_size] + latents_shape[1:] # # b*sp c t//sp h w # all_latents = ops.zeros(full_shape, dtype=latents.dtype) all_latents = self.all_gather(latents) - latents_list = mint.chunk(all_latents, world_size, axis = 0) + latents_list = mint.chunk(all_latents, world_size, axis=0) latents = ops.cat(latents_list, axis=2) # ==================make sp===================================== @@ -775,11 +779,10 @@ def __call__( # self.maybe_free_model_hooks() if not return_dict: - return (videos, ) + return (videos,) return OpenSoraPipelineOutput(videos=videos) - # def decode_latents(self, latents): # print(f'before vae decode {latents.shape}', ops.max(latents).item(), ops.min(latents).item(), ops.mean(latents).item(), ops.std(latents).item()) # video = self.vae.decode(latents.to(self.vae.vae.dtype)) # (b t c h w) @@ -788,9 +791,21 @@ def __call__( # return video def decode_latents_per_sample(self, latents): - print(f'before vae decode {latents.shape}', latents.max().item(), latents.min().item(), latents.mean().item(), latents.std().item()) + print( + f"before vae decode {latents.shape}", + latents.max().item(), + latents.min().item(), + latents.mean().item(), + latents.std().item(), + ) video = self.vae.decode(latents).to(ms.float32) # (b t c h w) - print(f'after vae decode {video.shape}', video.max().item(), video.min().item(), video.mean().item(), video.std().item()) + print( + f"after vae decode {video.shape}", + video.max().item(), + video.min().item(), + video.mean().item(), + video.std().item(), + ) video = ops.clip_by_value((video / 2.0 + 0.5), clip_value_min=0.0, clip_value_max=1.0).permute(0, 1, 3, 4, 2) return video # b t h w c diff --git a/examples/opensora_pku/opensora/utils/sample_utils.py b/examples/opensora_pku/opensora/utils/sample_utils.py index 31733a5419..ec692f4eb7 100644 --- a/examples/opensora_pku/opensora/utils/sample_utils.py +++ b/examples/opensora_pku/opensora/utils/sample_utils.py @@ -2,107 +2,99 @@ import glob import logging import os -import sys import time import numpy as np import pandas as pd import yaml -from PIL import Image -from tqdm import tqdm - -import mindspore as ms -from mindspore import nn - - -from mindone.diffusers import ( - DDIMScheduler, DDPMScheduler, PNDMScheduler, - EulerDiscreteScheduler, DPMSolverMultistepScheduler, - HeunDiscreteScheduler, EulerAncestralDiscreteScheduler, - DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler, - DPMSolverSinglestepScheduler, #CogVideoXDDIMScheduler, - FlowMatchEulerDiscreteScheduler - ) - from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info from opensora.dataset.text_dataset import create_dataloader -from opensora.utils.message_utils import print_banner -from opensora.utils.ms_utils import init_env -from opensora.utils.utils import _check_cfgs_in_parser, get_precision from opensora.models.causalvideovae import ae_stride_config, ae_wrapper + # from opensora.sample.caption_refiner import OpenSoraCaptionRefiner from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate -from examples.opensora_pku.opensora.models.diffusion.opensora.modeling_opensora import OpenSoraT2V_v1_3 -from examples.opensora_pku.opensora.models.diffusion.opensora.modules import Attention, LayerNorm -from opensora.sample.pipeline_opensora import OpenSoraPipeline from opensora.models.diffusion.common import PatchEmbed2D +from opensora.models.diffusion.opensora.modeling_opensora import OpenSoraT2V_v1_3 +from opensora.models.diffusion.opensora.modules import Attention, LayerNorm +from opensora.sample.pipeline_opensora import OpenSoraPipeline +from opensora.utils.message_utils import print_banner +from opensora.utils.utils import _check_cfgs_in_parser, get_precision +from PIL import Image +from tqdm import tqdm +from transformers import AutoTokenizer -from mindone.diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings -from mindone.diffusers import ( - DDIMScheduler, DDPMScheduler, PNDMScheduler, - EulerDiscreteScheduler, DPMSolverMultistepScheduler, - HeunDiscreteScheduler, EulerAncestralDiscreteScheduler, - DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler, - DPMSolverSinglestepScheduler, #CogVideoXDDIMScheduler, - FlowMatchEulerDiscreteScheduler - ) -from mindone.transformers import T5EncoderModel, MT5EncoderModel, CLIPTextModelWithProjection -from transformers import AutoTokenizer, MT5Tokenizer +import mindspore as ms +from mindspore import nn +from mindone.diffusers import DPMSolverSinglestepScheduler # CogVideoXDDIMScheduler, +from mindone.diffusers import ( + DDIMScheduler, + DDPMScheduler, + DEISMultistepScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + FlowMatchEulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + PNDMScheduler, +) +from mindone.diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings +from mindone.diffusers.training_utils import set_seed +from mindone.transformers import CLIPTextModelWithProjection, MT5EncoderModel, T5EncoderModel from mindone.utils.amp import auto_mixed_precision from mindone.utils.config import str2bool from mindone.utils.params import count_params from mindone.visualize.videos import save_videos -from mindone.diffusers.training_utils import set_seed logger = logging.getLogger(__name__) + def get_scheduler(args): kwargs = dict( - prediction_type=args.prediction_type, - rescale_betas_zero_snr=args.rescale_betas_zero_snr, - timestep_spacing="trailing" if args.rescale_betas_zero_snr else 'leading', + prediction_type=args.prediction_type, + rescale_betas_zero_snr=args.rescale_betas_zero_snr, + timestep_spacing="trailing" if args.rescale_betas_zero_snr else "leading", ) if args.v1_5_scheduler: - kwargs['beta_start'] = 0.00085 - kwargs['beta_end'] = 0.0120 - kwargs['beta_schedule'] = "scaled_linear" - if args.sample_method == 'DDIM': + kwargs["beta_start"] = 0.00085 + kwargs["beta_end"] = 0.0120 + kwargs["beta_schedule"] = "scaled_linear" + if args.sample_method == "DDIM": scheduler_cls = DDIMScheduler - kwargs['clip_sample'] = False - elif args.sample_method == 'EulerDiscrete': + kwargs["clip_sample"] = False + elif args.sample_method == "EulerDiscrete": scheduler_cls = EulerDiscreteScheduler - elif args.sample_method == 'DDPM': + elif args.sample_method == "DDPM": scheduler_cls = DDPMScheduler - kwargs['clip_sample'] = False - elif args.sample_method == 'DPMSolverMultistep': + kwargs["clip_sample"] = False + elif args.sample_method == "DPMSolverMultistep": scheduler_cls = DPMSolverMultistepScheduler - elif args.sample_method == 'DPMSolverSinglestep': + elif args.sample_method == "DPMSolverSinglestep": scheduler_cls = DPMSolverSinglestepScheduler - elif args.sample_method == 'PNDM': + elif args.sample_method == "PNDM": scheduler_cls = PNDMScheduler - kwargs.pop('rescale_betas_zero_snr', None) - elif args.sample_method == 'HeunDiscrete': ######## + kwargs.pop("rescale_betas_zero_snr", None) + elif args.sample_method == "HeunDiscrete": scheduler_cls = HeunDiscreteScheduler - elif args.sample_method == 'EulerAncestralDiscrete': + elif args.sample_method == "EulerAncestralDiscrete": scheduler_cls = EulerAncestralDiscreteScheduler - elif args.sample_method == 'DEISMultistep': + elif args.sample_method == "DEISMultistep": scheduler_cls = DEISMultistepScheduler - kwargs.pop('rescale_betas_zero_snr', None) - elif args.sample_method == 'KDPM2AncestralDiscrete': ######### + kwargs.pop("rescale_betas_zero_snr", None) + elif args.sample_method == "KDPM2AncestralDiscrete": scheduler_cls = KDPM2AncestralDiscreteScheduler # elif args.sample_method == 'CogVideoX': # scheduler_cls = CogVideoXDDIMScheduler - elif args.sample_method == 'FlowMatchEulerDiscrete': + elif args.sample_method == "FlowMatchEulerDiscrete": scheduler_cls = FlowMatchEulerDiscreteScheduler kwargs = {} else: - raise NameError(f'Unsupport sample_method {args.sample_method}') + raise NameError(f"Unsupport sample_method {args.sample_method}") scheduler = scheduler_cls(**kwargs) return scheduler - def prepare_pipeline(args): # VAE model initiate and weight loading print_banner("vae init") @@ -127,60 +119,55 @@ def prepare_pipeline(args): vae.vae.enable_tiling() vae.vae.tile_overlap_factor = args.tile_overlap_factor - ## use amp level O2 for causal 3D VAE with bfloat16 or float16 + # use amp level O2 for causal 3D VAE with bfloat16 or float16 if vae_dtype == ms.float16: custom_fp32_cells = [nn.GroupNorm] if args.vae_keep_gn_fp32 else [] else: custom_fp32_cells = [nn.AvgPool2d, TrilinearInterpolate] logger.info(f"Use amp level O2 for causal 3D VAE with dtype={vae_dtype}, custom_fp32_cells: {custom_fp32_cells}") vae = auto_mixed_precision(vae, amp_level="O2", dtype=vae_dtype, custom_fp32_cells=custom_fp32_cells) - + vae.set_train(False) for param in vae.get_parameters(): # freeze vae - param.requires_grad = False - + param.requires_grad = False + if args.decode_latents: print("To decode latents directly, skipped loading text endoers and transformer") return vae - + # Build text encoders print_banner("text encoder init") text_encoder_dtype = get_precision(args.text_encoder_precision) - if 'mt5' in args.text_encoder_name_1: + if "mt5" in args.text_encoder_name_1: text_encoder_1, loading_info = MT5EncoderModel.from_pretrained( - args.text_encoder_name_1, - cache_dir=args.cache_dir, + args.text_encoder_name_1, + cache_dir=args.cache_dir, output_loading_info=True, mindspore_dtype=text_encoder_dtype, - use_safetensors=True - ) + use_safetensors=True, + ) # loading_info.pop("unexpected_keys") # decoder weights are ignored # logger.info(f"Loaded MT5 Encoder: {loading_info}") - text_encoder_1 = text_encoder_1.set_train(False) + text_encoder_1 = text_encoder_1.set_train(False) else: text_encoder_1 = T5EncoderModel.from_pretrained( - args.text_encoder_name_1, cache_dir=args.cache_dir, - mindspore_dtype=text_encoder_dtype - ).set_train(False) - tokenizer_1 = AutoTokenizer.from_pretrained( - args.text_encoder_name_1, cache_dir=args.cache_dir - ) + args.text_encoder_name_1, cache_dir=args.cache_dir, mindspore_dtype=text_encoder_dtype + ).set_train(False) + tokenizer_1 = AutoTokenizer.from_pretrained(args.text_encoder_name_1, cache_dir=args.cache_dir) if args.text_encoder_name_2 is not None: text_encoder_2, loading_info = CLIPTextModelWithProjection.from_pretrained( - args.text_encoder_name_2, - cache_dir=args.cache_dir, + args.text_encoder_name_2, + cache_dir=args.cache_dir, mindspore_dtype=text_encoder_dtype, output_loading_info=True, use_safetensors=True, - ) + ) # loading_info.pop("unexpected_keys") # only load text model, ignore vision model # loading_info.pop("mising_keys") # Note: missed keys when loading open-clip models - # logger.info(f"Loaded CLIP Encoder: {loading_info}") + # logger.info(f"Loaded CLIP Encoder: {loading_info}") text_encoder_2 = text_encoder_2.set_train(False) - tokenizer_2 = AutoTokenizer.from_pretrained( - args.text_encoder_name_2, cache_dir=args.cache_dir - ) + tokenizer_2 = AutoTokenizer.from_pretrained(args.text_encoder_name_2, cache_dir=args.cache_dir) else: text_encoder_2, tokenizer_2 = None, None @@ -198,7 +185,7 @@ def prepare_pipeline(args): else: state_dict = None model_version = args.model_path.split("/")[-1] - if (args.version != 'v1_3') and (model_version.split("x")[0][:3] != "any"): + if (args.version != "v1_3") and (model_version.split("x")[0][:3] != "any"): if int(model_version.split("x")[0]) != args.num_frames: logger.warning( f"Detect that the loaded model version is {model_version}, but found a mismatched number of frames {model_version.split('x')[0]}" @@ -207,16 +194,20 @@ def prepare_pipeline(args): logger.warning( f"Detect that the loaded model version is {model_version}, but found a mismatched resolution {args.height}x{args.width}" ) - elif (args.version == 'v1_3') and (model_version.split("x")[0] == "any93x640x640"): # TODO: currently only release one model + elif (args.version == "v1_3") and ( + model_version.split("x")[0] == "any93x640x640" + ): # TODO: currently only release one model if (args.height % 32 != 0) or (args.width % 32 != 0): logger.warning( - f"Detect that the loaded model version is {model_version}, but found a mismatched resolution {args.height}x{args.width}. The resolution of the inference should be a multiple of 32." + f"Detect that the loaded model version is {model_version}, but found a mismatched resolution {args.height}x{args.width}. \ + The resolution of the inference should be a multiple of 32." ) if (args.num_frames - 1) % 4 != 0: logger.warning( - f"Detect that the loaded model version is {model_version}, but found a mismatched number of frames {args.num_frames}. Frames needs to be 4n+1, e.g. 93, 77, 61, 45, 29, 1 (image)" - ) - if args.version == 'v1_3': + f"Detect that the loaded model version is {model_version}, but found a mismatched number of frames {args.num_frames}. \ + Frames needs to be 4n+1, e.g. 93, 77, 61, 45, 29, 1 (image)" + ) + if args.version == "v1_3": # TODO # if args.model_type == 'inpaint' or args.model_type == 'i2v': # transformer_model = OpenSoraInpaint_v1_3.from_pretrained( @@ -224,48 +215,54 @@ def prepare_pipeline(args): # device_map=None, mindspore_dtype=weight_dtype # ).set_train(False) # else: - + transformer_model, logging_info = OpenSoraT2V_v1_3.from_pretrained( - args.model_path, + args.model_path, state_dict=state_dict, cache_dir=args.cache_dir, - FA_dtype = FA_dtype, - output_loading_info=True, - ) + FA_dtype=FA_dtype, + output_loading_info=True, + ) logger.info(logging_info) - elif args.version == 'v1_5': - if args.model_type == 'inpaint' or args.model_type == 'i2v': - raise NotImplementedError('Inpainting model is not available in v1_5') + elif args.version == "v1_5": + if args.model_type == "inpaint" or args.model_type == "i2v": + raise NotImplementedError("Inpainting model is not available in v1_5") else: from opensora.models.diffusion.opensora_v1_5.modeling_opensora import OpenSoraT2V_v1_5 + + weight_dtype = ms.float32 transformer_model = OpenSoraT2V_v1_5.from_pretrained( - args.model_path, cache_dir=args.cache_dir, - # device_map=None, - mindspore_dtype=weight_dtype - ) - + args.model_path, + cache_dir=args.cache_dir, + # device_map=None, + mindspore_dtype=weight_dtype, + ) + # Mixed precision dtype = get_precision(args.precision) if args.precision in ["fp16", "bf16"]: if not args.global_bf16: - amp_level = args.amp_level if dtype == ms.float16: - custom_fp32_cells=[LayerNorm, Attention, PatchEmbed2D, nn.SiLU, nn.GELU, PixArtAlphaCombinedTimestepSizeEmbeddings] + custom_fp32_cells = [ + LayerNorm, + Attention, + PatchEmbed2D, + nn.SiLU, + nn.GELU, + PixArtAlphaCombinedTimestepSizeEmbeddings, + ] else: - custom_fp32_cells= [ + custom_fp32_cells = [ nn.MaxPool2d, - nn.MaxPool3d, # do not support bf16 - PatchEmbed2D, # low accuracy if using bf16 + nn.MaxPool3d, # do not support bf16 + PatchEmbed2D, # low accuracy if using bf16 LayerNorm, nn.SiLU, nn.GELU, PixArtAlphaCombinedTimestepSizeEmbeddings, ] transformer_model = auto_mixed_precision( - transformer_model, - amp_level=args.amp_level, - dtype=dtype, - custom_fp32_cells=custom_fp32_cells + transformer_model, amp_level=args.amp_level, dtype=dtype, custom_fp32_cells=custom_fp32_cells ) logger.info( f"Set mixed precision to {args.amp_level} with dtype={args.precision}, custom fp32_cells {custom_fp32_cells}" @@ -274,7 +271,7 @@ def prepare_pipeline(args): logger.info(f"Using global bf16. Force model dtype from {dtype} to ms.bfloat16") dtype = ms.bfloat16 elif args.precision == "fp32": - amp_level = "O0" + pass else: raise ValueError(f"Unsupported precision {args.precision}") transformer_model = transformer_model.set_train(False) @@ -283,7 +280,7 @@ def prepare_pipeline(args): # Build scheduler scheduler = get_scheduler(args) - + # Build inference pipeline # pipeline_class = OpenSoraInpaintPipeline if args.model_type == 'inpaint' or args.model_type == 'i2v' else OpenSoraPipeline pipeline_class = OpenSoraPipeline @@ -293,13 +290,13 @@ def prepare_pipeline(args): text_encoder=text_encoder_1, tokenizer=tokenizer_1, scheduler=scheduler, - transformer=transformer_model, + transformer=transformer_model, text_encoder_2=text_encoder_2, tokenizer_2=tokenizer_2, ) - if args.save_memory: #TODO: Susan comment: I am not sure yet - print('enable_model_cpu_offload AND enable_sequential_cpu_offload AND enable_tiling') + if args.save_memory: # TODO: Susan comment: I am not sure yet + print("enable_model_cpu_offload AND enable_sequential_cpu_offload AND enable_tiling") pipeline.enable_model_cpu_offload() pipeline.enable_sequential_cpu_offload() if not args.enable_tiling: @@ -335,7 +332,8 @@ def prepare_pipeline(args): return pipeline -## See npu_config.py set_npu_env() + +# See npu_config.py set_npu_env() # def init_npu_env(args): # local_rank = int(os.getenv('RANK', 0)) # world_size = int(os.getenv('WORLD_SIZE', 1)) @@ -343,7 +341,7 @@ def prepare_pipeline(args): # args.world_size = world_size # torch_npu.npu.set_device(local_rank) # dist.init_process_group( -# backend='hccl', init_method='env://', +# backend='hccl', init_method='env://', # world_size=world_size, rank=local_rank # ) # if args.sp: @@ -351,7 +349,9 @@ def prepare_pipeline(args): # return args -def run_model_and_save_samples(args, pipeline, rank_id, device_num, save_dir, caption_refiner_model=None, enhance_video_model=None): +def run_model_and_save_samples( + args, pipeline, rank_id, device_num, save_dir, caption_refiner_model=None, enhance_video_model=None +): if args.seed is not None: set_seed(args.seed, rank=rank_id) @@ -431,7 +431,7 @@ def run_model_and_save_samples(args, pipeline, rank_id, device_num, save_dir, ca save_fp = os.path.join(save_dir, file_paths[i_sample]).replace(".npy", f".{args.video_extension}") save_video_data = decode_data[i_sample : i_sample + 1] save_videos(save_video_data, save_fp, loop=0, fps=args.fps) # (b t h w c) - + # Delete files that are no longer needed if os.path.exists(temp_dataset_csv): os.remove(temp_dataset_csv) @@ -440,7 +440,7 @@ def run_model_and_save_samples(args, pipeline, rank_id, device_num, save_dir, ca npy_files = glob.glob(os.path.join(save_dir, "*.npy")) for fp in npy_files: os.remove(fp) - + # TODO # if args.model_type == 'inpaint' or args.model_type == 'i2v': # if not isinstance(args.conditional_pixel_values_path, list): @@ -454,7 +454,7 @@ def run_model_and_save_samples(args, pipeline, rank_id, device_num, save_dir, ca high quality, high aesthetic, {} """ negative_prompt = """ - nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, + nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry. """ # positive_prompt = ( @@ -467,18 +467,18 @@ def run_model_and_save_samples(args, pipeline, rank_id, device_num, save_dir, ca # ) def generate(step, data, ext, conditional_pixel_values_path=None, mask_type=None): - - + prompt = [x for x in data["caption"]] if args.caption_refiner is not None: - if args.model_type != 'inpaint' and args.model_type != 'i2v': + if args.model_type != "inpaint" and args.model_type != "i2v": refine_prompt = caption_refiner_model.get_refiner_output(prompt) - print(f'\nOrigin prompt: {prompt}\n->\nRefine prompt: {refine_prompt}') + print(f"\nOrigin prompt: {prompt}\n->\nRefine prompt: {refine_prompt}") prompt = refine_prompt else: - # Due to the current use of LLM as the caption refiner, additional content that is not present in the control image will be added. Therefore, caption refiner is not used in this mode. - print('Caption refiner is not available for inpainting model, use the original prompt...') + # Due to the current use of LLM as the caption refiner, additional content that is not present in the + # control image will be added. Therefore, caption refiner is not used in this mode. + print("Caption refiner is not available for inpainting model, use the original prompt...") time.sleep(3) - # TODO + # TODO # input_prompt = positive_prompt.format(prompt) # if args.model_type == 'inpaint' or args.model_type == 'i2v': # print(f'\nConditional pixel values path: {conditional_pixel_values_path}') @@ -487,8 +487,8 @@ def generate(step, data, ext, conditional_pixel_values_path=None, mask_type=None # mask_type=mask_type, # crop_for_hw=args.crop_for_hw, # max_hxw=args.max_hxw, - # prompt=input_prompt, - # negative_prompt=negative_prompt, + # prompt=input_prompt, + # negative_prompt=negative_prompt, # num_frames=args.num_frames, # height=args.height, # width=args.width, @@ -498,15 +498,13 @@ def generate(step, data, ext, conditional_pixel_values_path=None, mask_type=None # max_sequence_length=args.max_sequence_length, # ).videos # else: - prompt = [x for x in data["caption"]] file_paths = data["file_path"] - input_prompt = positive_prompt.format(prompt[0]) # remove "[]" - saved_prompt1_dict = None - + input_prompt = positive_prompt.format(prompt[0]) # remove "[]" + videos = ( pipeline( - input_prompt, - negative_prompt=negative_prompt, + input_prompt, + negative_prompt=negative_prompt, num_frames=args.num_frames, height=args.height, width=args.width, @@ -561,8 +559,7 @@ def generate(step, data, ext, conditional_pixel_values_path=None, mask_type=None # else: for step, data in tqdm(enumerate(ds_iter), total=dataset_size): generate(step, data, ext) - - + # Delete files that are no longer needed if os.path.exists(temp_dataset_csv): os.remove(temp_dataset_csv) @@ -570,25 +567,38 @@ def generate(step, data, ext, conditional_pixel_values_path=None, mask_type=None def get_args(): parser = argparse.ArgumentParser() - - parser.add_argument("--version", type=str, default='v1_3', choices=['v1_3', 'v1_5']) + + parser.add_argument("--version", type=str, default="v1_3", choices=["v1_3", "v1_5"]) parser.add_argument("--caption_refiner", type=str, default=None, help="caption refiner model path") parser.add_argument("--enhance_video", type=str, default=None) - parser.add_argument("--text_encoder_name_1", type=str, default='DeepFloyd/t5-v1_1-xxl', help="google/mt5-xxl, DeepFloyd/t5-v1_1-xxl") - parser.add_argument("--text_encoder_name_2", type=str, default=None, help=" openai/clip-vit-large-patch14, (laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)") + parser.add_argument( + "--text_encoder_name_1", type=str, default="DeepFloyd/t5-v1_1-xxl", help="google/mt5-xxl, DeepFloyd/t5-v1_1-xxl" + ) + parser.add_argument( + "--text_encoder_name_2", + type=str, + default=None, + help=" openai/clip-vit-large-patch14, (laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)", + ) parser.add_argument("--num_samples_per_prompt", type=int, default=1) - parser.add_argument('--refine_caption', action='store_true') + parser.add_argument("--refine_caption", action="store_true") # parser.add_argument('--compile', action='store_true') - parser.add_argument("--prediction_type", type=str, default='epsilon', help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.") - parser.add_argument('--rescale_betas_zero_snr', action='store_true') - # parser.add_argument('--local_rank', type=int, default=-1) - # parser.add_argument('--world_size', type=int, default=1) + parser.add_argument( + "--prediction_type", + type=str, + default="epsilon", + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. \ + If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument("--rescale_betas_zero_snr", action="store_true") + # parser.add_argument('--local_rank', type=int, default=-1) + # parser.add_argument('--world_size', type=int, default=1) # parser.add_argument('--sp', action='store_true') - parser.add_argument('--v1_5_scheduler', action='store_true') - parser.add_argument('--conditional_pixel_values_path', type=str, default=None) - parser.add_argument('--mask_type', type=str, default=None) - parser.add_argument('--crop_for_hw', action='store_true') - parser.add_argument('--max_hxw', type=int, default=236544) #236544=512x462???? + parser.add_argument("--v1_5_scheduler", action="store_true") + parser.add_argument("--conditional_pixel_values_path", type=str, default=None) + parser.add_argument("--mask_type", type=str, default=None) + parser.add_argument("--crop_for_hw", action="store_true") + parser.add_argument("--max_hxw", type=int, default=236544) # 236544=512x462???? parser.add_argument( "--config", @@ -703,7 +713,9 @@ def get_args(): parser.add_argument( "--video_extension", default="mp4", choices=["gif", "mp4"], help="The file extension to save videos" ) - parser.add_argument("--model_type", type=str, default="dit", choices=["dit", "udit", "latte", 't2v', 'inpaint', 'i2v']) + parser.add_argument( + "--model_type", type=str, default="dit", choices=["dit", "udit", "latte", "t2v", "inpaint", "i2v"] + ) parser.add_argument("--cache_dir", type=str, default="./") parser.add_argument("--profile", default=False, type=str2bool, help="Profile or not") From 737854d2d4be9e0dd66e6ea3bba7c4212e17f835 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 21 Nov 2024 11:00:33 +0800 Subject: [PATCH 016/133] update sample script --- examples/opensora_pku/examples/sora.txt | 51 +++++++++++++++++++ .../diffusion/opensora/modeling_opensora.py | 10 ++-- .../opensora/sample/pipeline_opensora.py | 28 +++++----- .../single-device/sample_t2v_93x640.sh | 12 +++-- 4 files changed, 75 insertions(+), 26 deletions(-) create mode 100644 examples/opensora_pku/examples/sora.txt diff --git a/examples/opensora_pku/examples/sora.txt b/examples/opensora_pku/examples/sora.txt new file mode 100644 index 0000000000..192003a08a --- /dev/null +++ b/examples/opensora_pku/examples/sora.txt @@ -0,0 +1,51 @@ +A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, along red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is dampand reflective, creating a mirror effect of thecolorful lights. Many pedestrians walk about. +Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered tree sand dramatic snow capped mountains in the distance,mid afternoon lightwith wispy cloud sand a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field +A movie trailer featuring the adventures ofthe 30 year old spacemanwearing a redwool knitted motorcycle helmet, bluesky, saltdesert, cinematic style, shoton 35mm film, vivid colors.  +Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach.The crashing blue waters create white-tipped waves,while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green +shrubbery covers the cliffs edge. The steep drop from the road down to the beach is adramatic feat, with the cliff's edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway. +Animated scene features a close-up of a short fluffy monster kneeling beside a melting red candle.The art style is 3D and realistic,with a focus on lighting and texture.The mood of the painting is one of wonder and curiosity, as the monster gazes at the flame with wide eyes and +open mouth. lts pose and expression convey a sense of innocence and playfulness, as if it is exploring the world around it for the first time.The use of warm colors and dramatic lighting further enhances the cozy atmosphere of the image. +A gorgeously rendered papercraft world of a coral reef,rife with colorful fish and sea creatures. +This close-up shot of a Victoria crowned pigeon showcases its striking blue plumage and red chest. Its crest is made of delicate, lacy feathers, while its eye is a striking red color. The bird's head is tilted slightly to the side,giving the impression of it looking regal and majestic. The background is blurred,drawing attention to the bird's striking appearance. +Photorealistic closeup video of two pirate ships battling each other as they sail inside a cup of coffee. +A young man at his 20s is sitting on a piece of cloud in the sky, reading a book. +A petri dish with a bamboo forest growing within it that has tiny red pandas running around. +The camera rotates around a large stack of vintage televisions all showing different programs-1950s sci-fi movies, horror movies, news, static, a 1970s sitcom, etc, set inside a large New York museum gallery. +3D animation of a small, round, fluffy creature with big, expressive eyes explores a vibrant, enchanted forest. The creature, a whimsical blend of a rabbit and a squirrel, has soft blue fur and a bushy, striped tail. It hops along a sparkling stream,its eyes wide with wonder. The forest is alive with magical elements: flowers that glow and change colors, trees with leaves in shades of purple and silver, and small floating lights that resemble fireflies. The creature stops to interact playfully with a group of tiny, fairy-like beings dancing around a mushroom ring. The creature looks up in awe at a large, glowing tree that seems to be the heart of the forest. +Historical footage of California during the gold rush. +A close up view of a glass sphere that has a zen garden within it. There is a small dwarf in the sphere who is raking the zen garden and creating patterns in the sand. +Extreme close up of a 24 year old woman's eye blinking, standing in Marrakech during magic hour, cinematic film shot in 70mm, depth of field,vivid colors, cinematic. +A cartoon kangaroo disco dances. +A beautiful homemade video showing the people of Lagos, Nigeria in the year 2056. Shot with a mobile phone camera. +A cat waking up its sleeping owner demanding breakfast.The owner tries to ignore the cat, but the cat tries new tactics and finally the owner pulls out a secret stash of treats from under the pillow to hold the cat off a little longer. +Borneo wildlife on the Kinabatangan River +A Chinese Lunar New Year celebration video with Chinese Dragon. +The camera follows behind a white vintage SUv with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from it's tires, the sunlight shines on the Suv as it speeds along the dirt road,casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars orvehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains with a clear blue sky above with wispy clouds. +Reflections in the window of a train traveling through the Tokyo suburbs. +A drone camera circles around a beautiful historic church built on a rocky outcropping along the Amalfi Coast, the view showcases historic and magnificent architectural details and tiered pathways and patios, waves are seen crashing against the rocks below as the view overlooks the horizon of the coastal waters and hilly landscapes of the Amalfi Coast ltaly, several distant people are seen walking and enjoying vistas on patios of the dramatic ocean views, the warm glow of the afternoon sun creates a magical and romantic feeling to the scene, the view is stunning captured with beautiful photography +A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. lts tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock,its claws raised and ready to attack. The crab is brown and spiny,with long legs and antennae. The scene is captured from a wide angle,showing the vastness and depth of the ocean. The wateris clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred,creating a depth of field effect. +A flock of paper airplanes flutters through a dense jungle,weaving around trees as if they were migrating birds. +A beautiful silhouette animation shows a wolf howling at the moon,feeling lonely, untilit finds its pack. +New York City submerged like Atlantis.Fish,whales,sea turtles and sharks swim through the streets of New York. +A litter of golden retriever puppies playing in the snow.Their heads pop out of the snow, covered in. +Tour of an art gallery with many beautiful works of art in different styles. +Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. Gorgeous sakura petals are flying through the wind along with snowflakes. +A stop motion animation of a flower growing out of the windowsill of a suburban house. +The story of a robot's life in a cyberpunk setting. +An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt, he wears a brown beret and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and the Parisian streets and city in the background, depth of field, cinematic 35mm film. +Basketball through hoop then explodes +Archeologists discovera generic plastic chairin the desert,excavating and dusting it with great care +A grandmother with neatly combed grey hair stands behind a colorful birthday cake with numerous candles at a wood dining room table,expression is one of pure joy and happines with a happy glow in her eye. She leans forward and blows out the candles with a gentle puff, the cake has pink frosting and sprinkles and the candles cease to flicker,the grandmotherwears a light blue blouse adorned with floral patterns,several happy friends and family sitting at the table can be seen celebrating,out of focus.The scene is beautifully captured, cinematic, showing a 3/4 view of the grandmother and the dining room. Warm color tones and soft lighting enhance the mood +Step-printing scene of a person running, cinematic film shot in 35mm +Five gray wolf pups frolicking and chasing each other around a remote gravel road, surrounded by grass. The pups run and leap, chasing each other, and nipping at each other, playing. +Tiltshift of a construction site filled with workers, equipment, and heavy machinery. +A giant, towering cloud in the shape of a man looms overthe earth. The cloud man shoots lighting bolts down to the earth. +A Samoyed and a Golden Retriever dog are playfully romping through a futuristic neon city at night. The neon lights emitted from the nearby buildings glistens off of their fur. +The Glenfinnan Viaduct is a historic railway bridge in Scotland, UK, that crosses over the west highland line between the towns of Mallaig and Fort Wiliam. It is a stunning sight as a steam train leaves the bridge, traveling over the arch-covered viaduct. The landscape is dotted with lush greenery and rocky mountains, creating a picturesque backdrop forthe train journey. The sky is blue and the sun is shining,making for a beautiful day to explore this majestic spot. +The camera directly faces colorful buildings in Burano ltaly. An adorable dalmation looks through a window on a building on the ground floor. Many people are walking and cycling along the canal streets in front of the buildings. +An adorable happy otter confidently stands on a surfboard wearing a yellow lifejacket, riding along turquoise tropical waters near lush tropical islands,3D digital render art style. +This close-up shot of a chameleon showcases its striking color changing capabilities.The background is blurred, drawing attention to the animals striking appearance. +A corgi vlogging itself in tropical Maui. +A white and orange tabby cat is seen happily darting through a dense garden, as if chasing something.Its eyes are wide and happy as it jogs forward, scanning the branches, flowers, and leaves as it walks. The path is narrow as it makes its way between all the plants. the scene is captured from a ground-level angle, following the cat closely, giving a low and intimate perspective. The image is cinematic with warm tones and a grainy texture. The scattered daylight between the leaves and plants above creates awarm contrast, accentuating the cat's orange fur. The shot is clear and sharp, with a shallow depth of field. +Aerial view of Santorini during the blue hour, showcasing the stunning architecture of white Cycladic buildings with blue domes. The caldera views are breathtaking,and the lighting creates a beautiful, serene atmosphere. +Tiltshift of a construction site filled with workers, equipment, and heavy machinery. diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index d4d2d60938..a23890c730 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -1,17 +1,15 @@ import glob -import json import logging import os -from typing import Any, Dict, Optional +from typing import Optional -from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info +from opensora.acceleration.parallel_states import get_sequence_parallel_state from opensora.models.diffusion.common import PatchEmbed2D from opensora.models.diffusion.opensora.modules import Attention, BasicTransformerBlock, LayerNorm from opensora.npu_config import npu_config -from opensora.utils.utils import to_2tuple import mindspore as ms -from mindspore import mint, nn, ops +from mindspore import nn, ops from mindone.diffusers import __version__ from mindone.diffusers.configuration_utils import ConfigMixin, register_to_config @@ -349,8 +347,6 @@ def construct( encoder_hidden_states: Optional[ms.Tensor] = None, attention_mask: Optional[ms.Tensor] = None, encoder_attention_mask: Optional[ms.Tensor] = None, - return_dict: bool = True, - **kwargs, ): batch_size, c, frame, h, w = hidden_states.shape # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. diff --git a/examples/opensora_pku/opensora/sample/pipeline_opensora.py b/examples/opensora_pku/opensora/sample/pipeline_opensora.py index a642808dca..372c9f027c 100644 --- a/examples/opensora_pku/opensora/sample/pipeline_opensora.py +++ b/examples/opensora_pku/opensora/sample/pipeline_opensora.py @@ -791,21 +791,21 @@ def __call__( # return video def decode_latents_per_sample(self, latents): - print( - f"before vae decode {latents.shape}", - latents.max().item(), - latents.min().item(), - latents.mean().item(), - latents.std().item(), - ) + # print( + # f"before vae decode {latents.shape}", + # latents.max().item(), + # latents.min().item(), + # latents.mean().item(), + # latents.std().item(), + # ) video = self.vae.decode(latents).to(ms.float32) # (b t c h w) - print( - f"after vae decode {video.shape}", - video.max().item(), - video.min().item(), - video.mean().item(), - video.std().item(), - ) + # print( + # f"after vae decode {video.shape}", + # video.max().item(), + # video.min().item(), + # video.mean().item(), + # video.std().item(), + # ) video = ops.clip_by_value((video / 2.0 + 0.5), clip_value_min=0.0, clip_value_max=1.0).permute(0, 1, 3, 4, 2) return video # b t h w c diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh index 0d2b17e057..3ebdb7be9c 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh @@ -1,4 +1,4 @@ -# The DiT model is trained arbitrarily on stride=32. +# The DiT model is trained arbitrarily on stride=32. # So keep the resolution of the inference a multiple of 32. Frames needs to be 4n+1, e.g. 93, 77, 61, 45, 29, 1 (image). export DEVICE_ID=0 @@ -6,20 +6,22 @@ python opensora/sample/sample.py \ --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ --version v1_3 \ --num_frames 93 \ - --height 352 \ + --height 640 \ --width 640 \ --text_encoder_name_1 google/mt5-xxl \ - --text_prompt examples/prompt_list_0.txt \ + --text_prompt examples/sora.txt \ --ae WFVAEModel_D8_4x8x8 \ --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ - --save_img_path "./sample_videos/prompt_list_0_93x640_mt5" \ + --save_img_path "./sample_videos/sora_93x640_mt5" \ --fps 18 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ --enable_tiling \ --max_sequence_length 512 \ --sample_method EulerAncestralDiscrete \ + --seed 1234 \ --num_samples_per_prompt 1 \ --rescale_betas_zero_snr \ --prediction_type "v_prediction" \ - --mode 1 + --mode 1 \ + --precision bf16 \ From fc68b9130183402d5228670e81787017f38f422b Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 21 Nov 2024 11:13:44 +0800 Subject: [PATCH 017/133] update train scripts --- .../multi-devices/sample_t2v_29x480p_ddp.sh | 21 -------------- .../multi-devices/sample_t2v_29x480p_sp.sh | 23 --------------- .../multi-devices/sample_t2v_29x720p_ddp.sh | 21 -------------- .../multi-devices/sample_t2v_29x720p_sp.sh | 23 --------------- .../sample_t2v_93x640_ddp.sh} | 20 +++++++------ .../multi-devices/sample_t2v_93x640_sp.sh | 28 +++++++++++++++++++ .../single-device/sample_t2v_29x1280.sh | 25 ----------------- 7 files changed, 40 insertions(+), 121 deletions(-) delete mode 100644 examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_29x480p_ddp.sh delete mode 100644 examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_29x480p_sp.sh delete mode 100644 examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_29x720p_ddp.sh delete mode 100644 examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_29x720p_sp.sh rename examples/opensora_pku/scripts/text_condition/{single-device/sample_t2v_29x480p.sh => multi-devices/sample_t2v_93x640_ddp.sh} (54%) create mode 100644 examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh delete mode 100644 examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x1280.sh diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_29x480p_ddp.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_29x480p_ddp.sh deleted file mode 100644 index 9b31f8287d..0000000000 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_29x480p_ddp.sh +++ /dev/null @@ -1,21 +0,0 @@ - -msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="./sample_videos/prompt_list_0_29x480p_ddp/parallel_logs/" \ - opensora/sample/sample_t2v.py \ - --model_path LanguageBind/Open-Sora-Plan-v1.2.0/29x480p \ - --num_frames 29 \ - --height 480 \ - --width 640 \ - --cache_dir "./" \ - --text_encoder_name google/mt5-xxl \ - --text_prompt examples/prompt_list_0.txt \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path LanguageBind/Open-Sora-Plan-v1.2.0/vae\ - --save_img_path "./sample_videos/prompt_list_0_29x480p_ddp" \ - --fps 24 \ - --guidance_scale 7.5 \ - --num_sampling_steps 100 \ - --enable_tiling \ - --max_sequence_length 512 \ - --sample_method EulerAncestralDiscrete \ - --model_type "dit" \ - --use_parallel True \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_29x480p_sp.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_29x480p_sp.sh deleted file mode 100644 index d4e6df1ee0..0000000000 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_29x480p_sp.sh +++ /dev/null @@ -1,23 +0,0 @@ - -export ASCEND_RT_VISIBLE_DEVICES=0,1 -msrun --bind_core=True --worker_num=2 --local_worker_num=2 --master_port=9000 --log_dir="./sample_videos/prompt_list_0_29x480p_sp/parallel_logs/" \ - opensora/sample/sample_t2v.py \ - --model_path LanguageBind/Open-Sora-Plan-v1.2.0/29x480p \ - --num_frames 29 \ - --height 480 \ - --width 640 \ - --cache_dir "./" \ - --text_encoder_name google/mt5-xxl \ - --text_prompt examples/prompt_list_0.txt \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path LanguageBind/Open-Sora-Plan-v1.2.0/vae\ - --save_img_path "./sample_videos/prompt_list_0_29x480p_sp" \ - --fps 24 \ - --guidance_scale 7.5 \ - --num_sampling_steps 100 \ - --enable_tiling \ - --max_sequence_length 512 \ - --sample_method EulerAncestralDiscrete \ - --model_type "dit" \ - --use_parallel True \ - --sp_size 2 diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_29x720p_ddp.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_29x720p_ddp.sh deleted file mode 100644 index 0c8845b0d5..0000000000 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_29x720p_ddp.sh +++ /dev/null @@ -1,21 +0,0 @@ - -msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="./sample_videos/prompt_list_0_29x720p_ddp/parallel_logs/" \ - opensora/sample/sample_t2v.py \ - --model_path LanguageBind/Open-Sora-Plan-v1.2.0/29x720p \ - --num_frames 29 \ - --height 720 \ - --width 1280 \ - --cache_dir "./" \ - --text_encoder_name google/mt5-xxl \ - --text_prompt examples/prompt_list_0.txt \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path LanguageBind/Open-Sora-Plan-v1.2.0/vae\ - --save_img_path "./sample_videos/prompt_list_0_29x720p_ddp" \ - --fps 24 \ - --guidance_scale 7.5 \ - --num_sampling_steps 100 \ - --enable_tiling \ - --max_sequence_length 512 \ - --sample_method EulerAncestralDiscrete \ - --model_type "dit" \ - --use_parallel True \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_29x720p_sp.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_29x720p_sp.sh deleted file mode 100644 index ed8151e76b..0000000000 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_29x720p_sp.sh +++ /dev/null @@ -1,23 +0,0 @@ - -export ASCEND_RT_VISIBLE_DEVICES=0,1 -msrun --bind_core=True --worker_num=2 --local_worker_num=2 --master_port=9000 --log_dir="./sample_videos/prompt_list_0_29x720p_sp/parallel_logs/" \ - opensora/sample/sample_t2v.py \ - --model_path LanguageBind/Open-Sora-Plan-v1.2.0/29x720p \ - --num_frames 29 \ - --height 720 \ - --width 1280 \ - --cache_dir "./" \ - --text_encoder_name google/mt5-xxl \ - --text_prompt examples/prompt_list_0.txt \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path LanguageBind/Open-Sora-Plan-v1.2.0/vae\ - --save_img_path "./sample_videos/prompt_list_0_29x720p_sp" \ - --fps 24 \ - --guidance_scale 7.5 \ - --num_sampling_steps 100 \ - --enable_tiling \ - --max_sequence_length 512 \ - --sample_method EulerAncestralDiscrete \ - --model_type "dit" \ - --use_parallel True \ - --sp_size 2 diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x480p.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh similarity index 54% rename from examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x480p.sh rename to examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh index 9b061c59e9..4c9e5127e4 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x480p.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh @@ -1,22 +1,26 @@ -export DEVICE_ID=0 -python opensora/sample/sample_v1_3.py \ + +msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="./sample_videos/sora_93x640_mt5_ddp/parallel_logs/" \ + opensora/sample/sample.py \ --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ --version v1_3 \ - --num_frames 29 \ - --height 480 \ + --num_frames 93 \ + --height 640 \ --width 640 \ --text_encoder_name_1 google/mt5-xxl \ - --text_prompt examples/prompt_list_0.txt \ + --text_prompt examples/sora.txt \ --ae WFVAEModel_D8_4x8x8 \ --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ - --save_img_path "./sample_videos/prompt_list_0_29x480p" \ - --fps 24 \ + --save_img_path "./sample_videos/sora_93x640_mt5" \ + --fps 18 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ --enable_tiling \ --max_sequence_length 512 \ --sample_method EulerAncestralDiscrete \ + --seed 1234 \ --num_samples_per_prompt 1 \ --rescale_betas_zero_snr \ --prediction_type "v_prediction" \ - --mode 1 + --mode 1 \ + --precision bf16 \ + --use_parallel True \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh new file mode 100644 index 0000000000..063a4e9e71 --- /dev/null +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh @@ -0,0 +1,28 @@ + +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="./sample_videos/sora_93x640_mt5_sp/parallel_logs/" \ + opensora/sample/sample.py \ + --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ + --version v1_3 \ + --num_frames 93 \ + --height 640 \ + --width 640 \ + --text_encoder_name_1 google/mt5-xxl \ + --text_prompt examples/sora.txt \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ + --save_img_path "./sample_videos/sora_93x640_mt5" \ + --fps 18 \ + --guidance_scale 7.5 \ + --num_sampling_steps 100 \ + --enable_tiling \ + --max_sequence_length 512 \ + --sample_method EulerAncestralDiscrete \ + --seed 1234 \ + --num_samples_per_prompt 1 \ + --rescale_betas_zero_snr \ + --prediction_type "v_prediction" \ + --mode 1 \ + --precision bf16 \ + --use_parallel True \ + --sp_size 8 diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x1280.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x1280.sh deleted file mode 100644 index cf3373c426..0000000000 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_29x1280.sh +++ /dev/null @@ -1,25 +0,0 @@ -# The DiT model is trained arbitrarily on stride=32. -# So keep the resolution of the inference a multiple of 32. Frames needs to be 4n+1, e.g. 93, 77, 61, 45, 29, 1 (image). - -export DEVICE_ID=0 -python opensora/sample/sample.py \ - --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ - --version v1_3 \ - --num_frames 29 \ - --height 704 \ - --width 1280 \ - --text_encoder_name_1 google/mt5-xxl \ - --text_prompt examples/prompt_list_0.txt \ - --ae WFVAEModel_D8_4x8x8 \ - --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ - --save_img_path "./sample_videos/prompt_list_0_29x1280" \ - --fps 24 \ - --guidance_scale 7.5 \ - --num_sampling_steps 100 \ - --enable_tiling \ - --max_sequence_length 512 \ - --sample_method EulerAncestralDiscrete \ - --num_samples_per_prompt 1 \ - --rescale_betas_zero_snr \ - --prediction_type "v_prediction" \ - --mode 1 \ No newline at end of file From fdf374d50e28bb813d61cfd311f2c760a21f4b4b Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 21 Nov 2024 11:20:36 +0800 Subject: [PATCH 018/133] fix error --- .../opensora/models/diffusion/opensora/modeling_opensora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index a23890c730..31edf2913a 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -347,6 +347,7 @@ def construct( encoder_hidden_states: Optional[ms.Tensor] = None, attention_mask: Optional[ms.Tensor] = None, encoder_attention_mask: Optional[ms.Tensor] = None, + **kwargs, ): batch_size, c, frame, h, w = hidden_states.shape # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. From 43815c1f3e7d1d52395632848751fa05cade0689 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 21 Nov 2024 14:19:16 +0800 Subject: [PATCH 019/133] a valid file name --- .../models/diffusion/opensora/modeling_opensora.py | 1 + examples/opensora_pku/opensora/utils/sample_utils.py | 6 ++++-- examples/opensora_pku/opensora/utils/utils.py | 7 +++++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index 31edf2913a..c1481eafb7 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -340,6 +340,7 @@ def get_attention_mask(self, attention_mask): attention_mask = attention_mask.to(ms.bool_) # use bool for sdpa return attention_mask + @ms.jit # use graph mode def construct( self, hidden_states: ms.Tensor, diff --git a/examples/opensora_pku/opensora/utils/sample_utils.py b/examples/opensora_pku/opensora/utils/sample_utils.py index ec692f4eb7..f94950d809 100644 --- a/examples/opensora_pku/opensora/utils/sample_utils.py +++ b/examples/opensora_pku/opensora/utils/sample_utils.py @@ -18,7 +18,7 @@ from opensora.models.diffusion.opensora.modules import Attention, LayerNorm from opensora.sample.pipeline_opensora import OpenSoraPipeline from opensora.utils.message_utils import print_banner -from opensora.utils.utils import _check_cfgs_in_parser, get_precision +from opensora.utils.utils import _check_cfgs_in_parser, get_precision, remove_invalid_characters from PIL import Image from tqdm import tqdm from transformers import AutoTokenizer @@ -379,7 +379,9 @@ def run_model_and_save_samples( csv_file = {"path": [], "cap": []} for i in range(n): for i_video in range(args.num_videos_per_prompt): - csv_file["path"].append(f"{i_video}-{args.text_prompt[i].strip()[:100]}.{ext}") + csv_file["path"].append( + remove_invalid_characters(f"{i_video}-{args.text_prompt[i].strip()[:100]}.{ext}") + ) # a valid file name csv_file["cap"].append(args.text_prompt[i]) temp_dataset_csv = os.path.join(save_dir, "dataset.csv") pd.DataFrame.from_dict(csv_file).to_csv(temp_dataset_csv, index=False, columns=csv_file.keys()) diff --git a/examples/opensora_pku/opensora/utils/utils.py b/examples/opensora_pku/opensora/utils/utils.py index df9e0791b3..7ba97b7a4b 100644 --- a/examples/opensora_pku/opensora/utils/utils.py +++ b/examples/opensora_pku/opensora/utils/utils.py @@ -255,3 +255,10 @@ def _check_cfgs_in_parser(cfgs: dict, parser: argparse.ArgumentParser): for k in cfgs.keys(): if k not in actions_dest and k not in defaults_key: raise KeyError(f"{k} does not exist in ArgumentParser!") + + +def remove_invalid_characters(file_name): + file_name = file_name.replace(" ", "-") + valid_pattern = r"[^a-zA-Z0-9_.-]" + cleaned_file_name = re.sub(valid_pattern, "-", file_name) + return cleaned_file_name From 9c5c3cf3132ec3a97656db998c9b4150c91961ea Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 21 Nov 2024 15:47:14 +0800 Subject: [PATCH 020/133] lazy_inline & update save video func --- .../models/diffusion/opensora/modules.py | 20 +- .../opensora/train/train_t2v_diffusers.py | 234 +++++++++++------- .../opensora/utils/sample_utils.py | 8 + .../multi-devices/sample_t2v_93x640_ddp.sh | 2 +- mindone/visualize/videos.py | 23 ++ 5 files changed, 182 insertions(+), 105 deletions(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py index 8a0aaa8de1..ebb0274f48 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py @@ -2,7 +2,6 @@ import numbers from typing import Optional, Tuple -import numpy as np from opensora.acceleration.communications import AllToAll_SBH from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info from opensora.npu_config import npu_config @@ -13,7 +12,7 @@ from mindone.diffusers.models.attention import FeedForward from mindone.diffusers.models.attention_processor import Attention as Attention_ -from mindone.utils.version_control import check_valid_flash_attention, choose_flash_attention_dtype +from mindone.utils.version_control import check_valid_flash_attention from ..common import PositionGetter3D, RoPE3D @@ -269,7 +268,7 @@ def __call__( height: int = 16, width: int = 16, ) -> ms.Tensor: - residual = hidden_states + # residual = hidden_states if get_sequence_parallel_state(): sequence_length, batch_size, _ = ( @@ -378,6 +377,7 @@ def __call__( class BasicTransformerBlock(nn.Cell): + @ms.lazy_inline(policy="front") def __init__( self, dim: int, @@ -466,13 +466,13 @@ def __init__( def construct( self, hidden_states: ms.Tensor, - attention_mask: Optional[ms.Tensor] = None, - encoder_hidden_states: Optional[ms.Tensor] = None, - encoder_attention_mask: Optional[ms.Tensor] = None, - timestep: Optional[ms.Tensor] = None, - frame: int = None, - height: int = None, - width: int = None, + attention_mask: Optional[ms.Tensor], + encoder_hidden_states: Optional[ms.Tensor], + encoder_attention_mask: Optional[ms.Tensor], + timestep: Optional[ms.Tensor], + frame: int, + height: int, + width: int, ) -> ms.Tensor: # 0. Self-Attention if get_sequence_parallel_state(): diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index 99e3c94ff2..efce155a0d 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -20,6 +20,7 @@ from opensora.models.causalvideovae import ae_channel_config, ae_stride_config, ae_wrapper from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate from opensora.models.diffusion import Diffusion_models +from opensora.models.diffusion.common import PatchEmbed2D from opensora.models.diffusion.opensora.modules import Attention, LayerNorm from opensora.models.diffusion.opensora.net_with_loss import DiffusionWithLoss, DiffusionWithLossEval from opensora.train.commons import create_loss_scaler, parse_args @@ -29,20 +30,17 @@ from opensora.utils.message_utils import print_banner from opensora.utils.ms_utils import init_env from opensora.utils.utils import get_precision -from opensora.models.diffusion.common import PatchEmbed2D from mindone.diffusers.models.activations import SiLU -from mindone.diffusers.schedulers import ( - DDIMScheduler, DDPMScheduler, PNDMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, - FlowMatchEulerDiscreteScheduler,#CogVideoXDDIMScheduler, -) +from mindone.diffusers.schedulers import FlowMatchEulerDiscreteScheduler # CogVideoXDDIMScheduler, +from mindone.diffusers.schedulers import DDPMScheduler from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallbackEpoch, StopAtStepCallback from mindone.trainers.checkpoint import resume_train_network from mindone.trainers.lr_schedule import create_scheduler from mindone.trainers.optim import create_optimizer from mindone.trainers.train_step import TrainOneStepWrapper from mindone.trainers.zero import prepare_train_network -from mindone.transformers import T5EncoderModel, MT5EncoderModel, CLIPTextModelWithProjection +from mindone.transformers import CLIPTextModelWithProjection, MT5EncoderModel, T5EncoderModel from mindone.utils.amp import auto_mixed_precision from mindone.utils.config import str2bool from mindone.utils.logger import set_logger @@ -51,11 +49,6 @@ logger = logging.getLogger(__name__) -# @ms.jit_class -# class DDPMScheduler(DDPMScheduler_diffusers): -# pass - - def set_all_reduce_fusion( params, split_num: int = 7, @@ -77,6 +70,7 @@ def set_all_reduce_fusion( # Training Loop # ################################################################################# + def main(args): # 1. init save_src_strategy = args.use_parallel and args.parallel_mode != "data" @@ -100,7 +94,7 @@ def main(args): set_logger(name="", output_dir=args.output_dir, rank=rank_id, log_level=eval(args.log_level)) # 2. Init and load models - ## Load VAE + # Load VAE train_with_vae_latent = args.vae_latent_folder is not None and len(args.vae_latent_folder) > 0 if train_with_vae_latent: assert os.path.exists( @@ -118,14 +112,16 @@ def main(args): } vae = ae_wrapper[args.ae](args.ae_path, **kwarg) # vae.vae_scale_factor = ae_stride_config[args.ae] - + if vae_dtype == ms.float16: custom_fp32_cells = [nn.GroupNorm] if args.vae_keep_gn_fp32 else [] else: custom_fp32_cells = [nn.AvgPool2d, TrilinearInterpolate] - logger.info(f"Use amp level O2 for causal 3D VAE with dtype={vae_dtype}, custom_fp32_cells: {custom_fp32_cells}") + logger.info( + f"Use amp level O2 for causal 3D VAE with dtype={vae_dtype}, custom_fp32_cells: {custom_fp32_cells}" + ) vae = auto_mixed_precision(vae, amp_level="O2", dtype=vae_dtype, custom_fp32_cells=custom_fp32_cells) - + vae.set_train(False) for param in vae.get_parameters(): # freeze vae param.requires_grad = False @@ -151,7 +147,9 @@ def main(args): assert ( args.max_height % ae_stride_h == 0 ), f"Height must be divisible by ae_stride_h, but found Height ({args.max_height}), ae_stride_h ({ae_stride_h})." - assert (args.num_frames - 1) % ae_stride_t == 0, f"(Frames - 1) must be divisible by ae_stride_t, but found num_frames ({args.num_frames}), ae_stride_t ({ae_stride_t})." + assert ( + args.num_frames - 1 + ) % ae_stride_t == 0, f"(Frames - 1) must be divisible by ae_stride_t, but found num_frames ({args.num_frames}), ae_stride_t ({ae_stride_t})." assert ( args.max_width % ae_stride_h == 0 ), f"Width size must be divisible by ae_stride_h, but found Width ({args.max_width}), ae_stride_h ({ae_stride_h})." @@ -161,8 +159,7 @@ def main(args): vae.latent_size = latent_size = (args.max_height // ae_stride_h, args.max_width // ae_stride_w) args.latent_size_t = latent_size_t = (args.num_frames - 1) // ae_stride_t + 1 - - ## Load diffusion transformer + # Load diffusion transformer print_banner("Transformer model init") FA_dtype = get_precision(args.precision) if get_precision(args.precision) != ms.float32 else ms.bfloat16 model = Diffusion_models[args.model]( @@ -174,9 +171,9 @@ def main(args): interpolation_scale_h=args.interpolation_scale_h, interpolation_scale_w=args.interpolation_scale_w, interpolation_scale_t=args.interpolation_scale_t, - sparse1d=args.sparse1d, - sparse_n=args.sparse_n, - skip_connection=args.skip_connection, + sparse1d=args.sparse1d, + sparse_n=args.sparse_n, + skip_connection=args.skip_connection, use_recompute=args.gradient_checkpointing, num_no_recompute=args.num_no_recompute, FA_dtype=FA_dtype, @@ -222,61 +219,59 @@ def main(args): logger.info("Use random initialization for transformer") model.set_train(True) - ## Load text encoder + # Load text encoder if not args.text_embed_cache: print_banner("text encoder init") text_encoder_dtype = get_precision(args.text_encoder_precision) - if 'mt5' in args.text_encoder_name_1: + if "mt5" in args.text_encoder_name_1: text_encoder_1, loading_info = MT5EncoderModel.from_pretrained( - args.text_encoder_name_1, - cache_dir=args.cache_dir, + args.text_encoder_name_1, + cache_dir=args.cache_dir, output_loading_info=True, mindspore_dtype=text_encoder_dtype, - use_safetensors=True - ) + use_safetensors=True, + ) loading_info.pop("unexpected_keys") # decoder weights are ignored logger.info(f"Loaded MT5 Encoder: {loading_info}") - text_encoder_1 = text_encoder_1.set_train(False) + text_encoder_1 = text_encoder_1.set_train(False) else: text_encoder_1 = T5EncoderModel.from_pretrained( - args.text_encoder_name_1, cache_dir=args.cache_dir, - mindspore_dtype=text_encoder_dtype - ).set_train(False) + args.text_encoder_name_1, cache_dir=args.cache_dir, mindspore_dtype=text_encoder_dtype + ).set_train(False) text_encoder_2 = None if args.text_encoder_name_2 is not None: text_encoder_2, loading_info = CLIPTextModelWithProjection.from_pretrained( - args.text_encoder_name_2, - cache_dir=args.cache_dir, + args.text_encoder_name_2, + cache_dir=args.cache_dir, mindspore_dtype=text_encoder_dtype, output_loading_info=True, use_safetensors=True, - ) + ) loading_info.pop("unexpected_keys") # only load text model, ignore vision model # loading_info.pop("mising_keys") # Note: missed keys when loading open-clip models - logger.info(f"Loaded CLIP Encoder: {loading_info}") + logger.info(f"Loaded CLIP Encoder: {loading_info}") text_encoder_2 = text_encoder_2.set_train(False) else: text_encoder_1 = None text_encoder_2 = None text_encoder_dtype = None - kwargs = dict( - prediction_type=args.prediction_type, - rescale_betas_zero_snr=args.rescale_betas_zero_snr - ) + kwargs = dict(prediction_type=args.prediction_type, rescale_betas_zero_snr=args.rescale_betas_zero_snr) if args.cogvideox_scheduler: + from mindone.diffusers import CogVideoXDDIMScheduler + noise_scheduler = CogVideoXDDIMScheduler(**kwargs) elif args.v1_5_scheduler: - kwargs['beta_start'] = 0.00085 - kwargs['beta_end'] = 0.0120 - kwargs['beta_schedule'] = "scaled_linear" + kwargs["beta_start"] = 0.00085 + kwargs["beta_end"] = 0.0120 + kwargs["beta_schedule"] = "scaled_linear" noise_scheduler = DDPMScheduler(**kwargs) elif args.rf_scheduler: noise_scheduler = FlowMatchEulerDiscreteScheduler() - noise_scheduler_copy = copy.deepcopy(noise_scheduler) + # noise_scheduler_copy = copy.deepcopy(noise_scheduler) else: noise_scheduler = DDPMScheduler(**kwargs) - + # Get the target for loss depending on the prediction type if args.prediction_type is not None: # set prediction_type of scheduler if defined @@ -309,7 +304,7 @@ def main(args): # TODO: replace it with new dataset assert args.dataset == "t2v", "Support t2v dataset only." print_banner("Training dataset Loading...") - + # Setup data: # TODO: to use in v1.3 if args.trained_data_global_step is not None: @@ -325,13 +320,13 @@ def main(args): args.train_batch_size, world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), lengths=train_dataset.lengths, - group_frame=args.group_frame, #v1.2 - group_resolution=args.group_resolution, #v1.2 - initial_global_step_for_sampler = initial_global_step_for_sampler, #TODO: use in v1.3 - group_data=args.group_data #TODO: use in v1.3 + group_frame=args.group_frame, # v1.2 + group_resolution=args.group_resolution, # v1.2 + initial_global_step_for_sampler=initial_global_step_for_sampler, # TODO: use in v1.3 + group_data=args.group_data, # TODO: use in v1.3 ) - if (args.group_frame or args.group_resolution) #v1.2 - else None #v1.2 + if (args.group_frame or args.group_resolution) # v1.2 + else None # v1.2 ) collate_fn = Collate( args.train_batch_size, @@ -762,7 +757,7 @@ def main(args): def parse_t2v_train_args(parser): - ######## TODO: NEW in v1.3 , but may not use ### + # TODO: NEW in v1.3 , but may not use # dataset & dataloader parser.add_argument("--max_hxw", type=int, default=None) parser.add_argument("--min_hxw", type=int, default=None) @@ -774,61 +769,112 @@ def parse_t2v_train_args(parser): parser.add_argument("--use_decord", action="store_true") # text encoder & vae & diffusion model - parser.add_argument('--vae_fp32', action='store_true') - parser.add_argument('--extra_save_mem', action='store_true') - parser.add_argument("--text_encoder_name_1", type=str, default='DeepFloyd/t5-v1_1-xxl') + parser.add_argument("--vae_fp32", action="store_true") + parser.add_argument("--extra_save_mem", action="store_true") + parser.add_argument("--text_encoder_name_1", type=str, default="DeepFloyd/t5-v1_1-xxl") parser.add_argument("--text_encoder_name_2", type=str, default=None) - parser.add_argument('--sparse1d', action='store_true') - parser.add_argument('--sparse_n', type=int, default=2) - parser.add_argument('--skip_connection', action='store_true') - parser.add_argument('--cogvideox_scheduler', action='store_true') - parser.add_argument('--v1_5_scheduler', action='store_true') - parser.add_argument('--rf_scheduler', action='store_true') - parser.add_argument("--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]) - parser.add_argument("--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme.") - parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.") - parser.add_argument("--mode_scale", type=float, default=1.29, help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.") + parser.add_argument("--sparse1d", action="store_true") + parser.add_argument("--sparse_n", type=int, default=2) + parser.add_argument("--skip_connection", action="store_true") + parser.add_argument("--cogvideox_scheduler", action="store_true") + parser.add_argument("--v1_5_scheduler", action="store_true") + parser.add_argument("--rf_scheduler", action="store_true") + parser.add_argument( + "--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"] + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) # diffusion setting parser.add_argument("--offload_ema", action="store_true", help="Offload EMA model to CPU during training step.") parser.add_argument("--foreach_ema", action="store_true", help="Use faster foreach implementation of EMAModel.") - parser.add_argument('--rescale_betas_zero_snr', action='store_true') + parser.add_argument("--rescale_betas_zero_snr", action="store_true") # validation & logs parser.add_argument("--enable_profiling", action="store_true") parser.add_argument("--num_sampling_steps", type=int, default=20) - parser.add_argument('--guidance_scale', type=float, default=4.5) - parser.add_argument("--checkpoints_total_limit", type=int, default=None, help=("Max number of checkpoints to store.")) + parser.add_argument("--guidance_scale", type=float, default=4.5) + parser.add_argument( + "--checkpoints_total_limit", type=int, default=None, help=("Max number of checkpoints to store.") + ) # optimizer & scheduler - parser.add_argument("--optimizer", type=str, default="adamW", help='The optimizer type to use. Choose between ["AdamW", "prodigy"]') - parser.add_argument("--learning_rate", type=float, default=1e-4, help="Initial learning rate (after the potential warmup period) to use.") - parser.add_argument("--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW") - parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers.") - parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers.") + parser.add_argument( + "--optimizer", type=str, default="adamW", help='The optimizer type to use. Choose between ["AdamW", "prodigy"]' + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") parser.add_argument("--adam_weight_decay", type=float, default=1e-02, help="Weight decay to use for unet params") - parser.add_argument("--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder") - parser.add_argument("--adam_epsilon", type=float, default=1e-15, help="Epsilon value for the Adam optimizer and Prodigy optimizers.") - parser.add_argument("--prodigy_use_bias_correction", type=bool, default=True, help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW") - parser.add_argument("--prodigy_safeguard_warmup", type=bool, default=True, help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. Ignored if optimizer is adamW") - parser.add_argument("--prodigy_beta3", type=float, default=None, - help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " - "uses the value of square root of beta2. Ignored if optimizer is adamW", - ) - parser.add_argument("--allow_tf32", action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), - ) - parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), - ) + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder" + ) + parser.add_argument( + "--adam_epsilon", type=float, default=1e-15, help="Epsilon value for the Adam optimizer and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") ######################## diff --git a/examples/opensora_pku/opensora/utils/sample_utils.py b/examples/opensora_pku/opensora/utils/sample_utils.py index f94950d809..418ea86806 100644 --- a/examples/opensora_pku/opensora/utils/sample_utils.py +++ b/examples/opensora_pku/opensora/utils/sample_utils.py @@ -314,6 +314,14 @@ def prepare_pipeline(args): [ f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}", f"Jit level: {args.jit_level}", + f"Distributed mode: {args.use_parallel}" + + ( + f"\nParallel mode: {args.parallel_mode}" + + (f"{args.zero_stage}" if args.parallel_mode == "zero" else "") + if args.use_parallel + else "" + ) + + (f"\nsp_size: {args.sp_size}" if args.sp_size != 1 else ""), f"Num of samples: {len(args.text_prompt)}", f"Num params: {num_params:,} (latte: {num_params_latte:,}, vae: {num_params_vae:,})", f"Num trainable params: {num_params_trainable:,}", diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh index 4c9e5127e4..36dc9d73c4 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh @@ -1,4 +1,4 @@ - +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="./sample_videos/sora_93x640_mt5_ddp/parallel_logs/" \ opensora/sample/sample.py \ --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ diff --git a/mindone/visualize/videos.py b/mindone/visualize/videos.py index 95bd126c6a..28009ab01e 100644 --- a/mindone/visualize/videos.py +++ b/mindone/visualize/videos.py @@ -3,6 +3,7 @@ from typing import Union import av +import cv2 import imageio import numpy as np @@ -20,6 +21,13 @@ def create_video_from_rgb_numpy_arrays(image_arrays, output_file, fps: Union[int Credit to Perlexity """ + try: + save_video_file_using_av(image_arrays, output_file, fps) + except Exception: + save_video_file_using_cv2(image_arrays, output_file, fps) + + +def save_video_file_using_av(image_arrays, output_file, fps): # Get the dimensions of the first image height, width, _ = image_arrays[0].shape @@ -48,6 +56,21 @@ def create_video_from_rgb_numpy_arrays(image_arrays, output_file, fps: Union[int container.close() +def save_video_file_using_cv2(image_arrays, output_file, fps): + # Get the dimensions of the first image + height, width, _ = image_arrays[0].shape + # Define the codec and create a VideoWriter object + fourcc = cv2.VideoWriter_fourcc(*"mp4v") # Codec for MP4 + video_writer = cv2.VideoWriter(output_file, fourcc, fps, (width, height)) + + # Write each frame to the video + for img in image_arrays: + video_writer.write(img) + + # Release the VideoWriter + video_writer.release() + + def create_video_from_numpy_frames(frames: np.ndarray, path: str, fps: Union[int, float] = 8, fmt="gif", loop=0): """ Args: From 312454664c7c4ac5b7cd4abf4d40fe9285585277 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 21 Nov 2024 15:49:14 +0800 Subject: [PATCH 021/133] use cv2 to save videos --- examples/opensora_pku/examples/rec_video.py | 2 +- .../opensora_pku/examples/rec_video_folder.py | 2 +- .../causalvideovae/sample/rec_video_vae.py | 3 +- .../opensora/utils/sample_utils.py | 2 +- .../opensora/utils/video_utils.py | 82 +++++++++++++++++++ mindone/visualize/videos.py | 23 ------ 6 files changed, 86 insertions(+), 28 deletions(-) create mode 100644 examples/opensora_pku/opensora/utils/video_utils.py diff --git a/examples/opensora_pku/examples/rec_video.py b/examples/opensora_pku/examples/rec_video.py index 04fe503f1d..2e342c1210 100644 --- a/examples/opensora_pku/examples/rec_video.py +++ b/examples/opensora_pku/examples/rec_video.py @@ -13,7 +13,6 @@ mindone_lib_path = os.path.abspath("../../") sys.path.insert(0, mindone_lib_path) from mindone.utils.logger import set_logger -from mindone.visualize.videos import save_videos sys.path.append(".") from functools import partial @@ -24,6 +23,7 @@ from opensora.models.causalvideovae import ae_wrapper from opensora.npu_config import npu_config from opensora.utils.utils import get_precision +from opensora.utils.video_utils import save_videos logger = logging.getLogger(__name__) diff --git a/examples/opensora_pku/examples/rec_video_folder.py b/examples/opensora_pku/examples/rec_video_folder.py index a866277b40..28e9da1942 100644 --- a/examples/opensora_pku/examples/rec_video_folder.py +++ b/examples/opensora_pku/examples/rec_video_folder.py @@ -13,7 +13,6 @@ from mindone.utils.config import str2bool from mindone.utils.logger import set_logger -from mindone.visualize.videos import save_videos sys.path.append(".") from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info @@ -21,6 +20,7 @@ from opensora.models.causalvideovae.model.dataset_videobase import VideoDataset, create_dataloader from opensora.npu_config import npu_config from opensora.utils.utils import get_precision +from opensora.utils.video_utils import save_videos logger = logging.getLogger(__name__) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/sample/rec_video_vae.py b/examples/opensora_pku/opensora/models/causalvideovae/sample/rec_video_vae.py index e18916b615..e4d91b24d2 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/sample/rec_video_vae.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/sample/rec_video_vae.py @@ -10,8 +10,7 @@ from opensora.models.causalvideovae.model.dataset_videobase import VideoDataset, create_dataloader from opensora.utils.ms_utils import init_env from opensora.utils.utils import get_precision - -from mindone.visualize.videos import save_videos +from opensora.utils.video_utils import save_videos def main(args: argparse.Namespace): diff --git a/examples/opensora_pku/opensora/utils/sample_utils.py b/examples/opensora_pku/opensora/utils/sample_utils.py index 418ea86806..10c92a4aa1 100644 --- a/examples/opensora_pku/opensora/utils/sample_utils.py +++ b/examples/opensora_pku/opensora/utils/sample_utils.py @@ -19,6 +19,7 @@ from opensora.sample.pipeline_opensora import OpenSoraPipeline from opensora.utils.message_utils import print_banner from opensora.utils.utils import _check_cfgs_in_parser, get_precision, remove_invalid_characters +from opensora.utils.video_utils import save_videos from PIL import Image from tqdm import tqdm from transformers import AutoTokenizer @@ -45,7 +46,6 @@ from mindone.utils.amp import auto_mixed_precision from mindone.utils.config import str2bool from mindone.utils.params import count_params -from mindone.visualize.videos import save_videos logger = logging.getLogger(__name__) diff --git a/examples/opensora_pku/opensora/utils/video_utils.py b/examples/opensora_pku/opensora/utils/video_utils.py new file mode 100644 index 0000000000..65b2edccfa --- /dev/null +++ b/examples/opensora_pku/opensora/utils/video_utils.py @@ -0,0 +1,82 @@ +import os +from typing import Union + +import cv2 +import imageio +import numpy as np + +__all__ = ["save_videos", "create_video_from_numpy_frames"] + + +def create_video_from_rgb_numpy_arrays(image_arrays, output_file, fps: Union[int, float] = 30): + """ + Creates an MP4 video file from a series of RGB NumPy array images using opencv. + + Parameters: + image_arrays (list): A list of RGB NumPy array images. + output_file (str): The path and filename of the output MP4 video file. + fps (int): The desired frames per second for the output video. Default is 30. + + Credit to Perlexity + """ + # Get the dimensions of the first image + height, width, _ = image_arrays[0].shape + # Define the codec and create a VideoWriter object + fourcc = cv2.VideoWriter_fourcc(*"mp4v") # Codec for MP4 + video_writer = cv2.VideoWriter(output_file, fourcc, fps, (width, height)) + + # Write each frame to the video + for img in image_arrays: + video_writer.write(img) + + # Release the VideoWriter + video_writer.release() + + +def create_video_from_numpy_frames(frames: np.ndarray, path: str, fps: Union[int, float] = 8, fmt="gif", loop=0): + """ + Args: + frames: shape (f h w 3), range [0, 255], order rgb + """ + if fmt == "gif": + imageio.mimsave(path, frames, duration=1 / fps, loop=loop) + elif fmt == "png": + for i in range(len(frames)): + imageio.imwrite(path.replace(".png", f"-{i:04}.png"), frames[i]) + elif fmt == "mp4": + create_video_from_rgb_numpy_arrays(frames, path, fps=fps) + + +def save_videos(frames: np.ndarray, path: str, fps: Union[int, float] = 8, loop=0, concat=False): + """ + Save video frames to gif or mp4 files + Args: + frames: video frames in shape (b f h w 3), pixel value in [0, 1], RGB mode. + path: file path to save the output gif + fps: frames per sencond in the output gif. 1/fps = display duration per frame + concat: if True and b>1, all videos will be concatnated in grids and saved as one gif. + loop: number of loops to play. If 0, it will play endlessly. + """ + fmt = path.split(".")[-1] + assert fmt in ["gif", "mp4", "png"] + + # input frames: (b f H W 3), normalized to [0, 1] + frames = (frames * 255).round().clip(0, 255).astype(np.uint8) + os.makedirs(os.path.dirname(path), exist_ok=True) + + if len(frames.shape) == 4: + create_video_from_numpy_frames(frames, path, fps, fmt, loop) + else: + b, f, h, w, _ = frames.shape + if b > 1: + if concat: + canvas = np.array((f, h, w * b, 3), dtype=np.uint8) + for idx in range(b): + canvas[:, :, (w * idx) : (w * (idx + 1)), :] = frames[idx] + create_video_from_numpy_frames(canvas, path, fps, fmt, loop) + else: + for idx in range(b): + cur_path = path.replace(f".{fmt}", f"-{idx}.{fmt}") + create_video_from_numpy_frames(frames[idx], cur_path, fps, fmt, loop) + else: + create_video_from_numpy_frames(frames[0], path, fps, fmt, loop) diff --git a/mindone/visualize/videos.py b/mindone/visualize/videos.py index 28009ab01e..95bd126c6a 100644 --- a/mindone/visualize/videos.py +++ b/mindone/visualize/videos.py @@ -3,7 +3,6 @@ from typing import Union import av -import cv2 import imageio import numpy as np @@ -21,13 +20,6 @@ def create_video_from_rgb_numpy_arrays(image_arrays, output_file, fps: Union[int Credit to Perlexity """ - try: - save_video_file_using_av(image_arrays, output_file, fps) - except Exception: - save_video_file_using_cv2(image_arrays, output_file, fps) - - -def save_video_file_using_av(image_arrays, output_file, fps): # Get the dimensions of the first image height, width, _ = image_arrays[0].shape @@ -56,21 +48,6 @@ def save_video_file_using_av(image_arrays, output_file, fps): container.close() -def save_video_file_using_cv2(image_arrays, output_file, fps): - # Get the dimensions of the first image - height, width, _ = image_arrays[0].shape - # Define the codec and create a VideoWriter object - fourcc = cv2.VideoWriter_fourcc(*"mp4v") # Codec for MP4 - video_writer = cv2.VideoWriter(output_file, fourcc, fps, (width, height)) - - # Write each frame to the video - for img in image_arrays: - video_writer.write(img) - - # Release the VideoWriter - video_writer.release() - - def create_video_from_numpy_frames(frames: np.ndarray, path: str, fps: Union[int, float] = 8, fmt="gif", loop=0): """ Args: From eec25f19413b7c0e1829e0baf84d310da41169a2 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 21 Nov 2024 16:04:50 +0800 Subject: [PATCH 022/133] update requirements --- examples/opensora_pku/requirements.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/opensora_pku/requirements.txt b/examples/opensora_pku/requirements.txt index ca328d81ac..6860edc09c 100644 --- a/examples/opensora_pku/requirements.txt +++ b/examples/opensora_pku/requirements.txt @@ -9,13 +9,12 @@ imagesize toolz tqdm mindcv -decord safetensors omegaconf pyyaml sentencepiece mindnlp==0.4.0 transformers>=4.46.0 -pyav bs4 huggingface_hub>=0.22.2 +decord From b106f9f13e4be7bed8e34a59cff94ad683ed76c9 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 21 Nov 2024 19:54:55 +0800 Subject: [PATCH 023/133] save path correction --- .../text_condition/multi-devices/sample_t2v_93x640_ddp.sh | 2 +- .../text_condition/multi-devices/sample_t2v_93x640_sp.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh index 36dc9d73c4..250f509792 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh @@ -10,7 +10,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --text_prompt examples/sora.txt \ --ae WFVAEModel_D8_4x8x8 \ --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ - --save_img_path "./sample_videos/sora_93x640_mt5" \ + --save_img_path "./sample_videos/sora_93x640_mt5_ddp" \ --fps 18 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh index 063a4e9e71..7ecd0da215 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh @@ -11,7 +11,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --text_prompt examples/sora.txt \ --ae WFVAEModel_D8_4x8x8 \ --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ - --save_img_path "./sample_videos/sora_93x640_mt5" \ + --save_img_path "./sample_videos/sora_93x640_mt5_sp" \ --fps 18 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ From 4260c228b46cc8d9e92e1abb4f5a13486e81553d Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 30 Oct 2024 15:47:20 +0800 Subject: [PATCH 024/133] revise generator with loss vae --- .../eval/eval_common_metrics.py | 262 ----- .../model/causal_vae/__init__.py | 36 - .../model/causal_vae/modeling_causalvae.py | 963 ------------------ .../model/losses/net_with_loss.py | 3 +- 4 files changed, 2 insertions(+), 1262 deletions(-) delete mode 100644 examples/opensora_pku/opensora/models/causalvideovae/eval/eval_common_metrics.py delete mode 100644 examples/opensora_pku/opensora/models/causalvideovae/model/causal_vae/__init__.py delete mode 100644 examples/opensora_pku/opensora/models/causalvideovae/model/causal_vae/modeling_causalvae.py diff --git a/examples/opensora_pku/opensora/models/causalvideovae/eval/eval_common_metrics.py b/examples/opensora_pku/opensora/models/causalvideovae/eval/eval_common_metrics.py deleted file mode 100644 index 0993a4a543..0000000000 --- a/examples/opensora_pku/opensora/models/causalvideovae/eval/eval_common_metrics.py +++ /dev/null @@ -1,262 +0,0 @@ -"""Calculates the CLIP Scores - -The CLIP model is a contrasitively learned language-image model. There is -an image encoder and a text encoder. It is believed that the CLIP model could -measure the similarity of cross modalities. Please find more information from -https://github.com/openai/CLIP. - -The CLIP Score measures the Cosine Similarity between two embedded features. -This repository utilizes the pretrained CLIP Model to calculate -the mean average of cosine similarities. - -See --help to see further details. - -Code apapted from https://github.com/mseitzer/pytorch-fid and https://github.com/openai/CLIP. - -Copyright 2023 The Hong Kong Polytechnic University - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -import math -import os -import os.path as osp -import sys -from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser - -import numpy as np -from decord import VideoReader - -mindone_lib_path = os.path.abspath("../../") -sys.path.insert(0, mindone_lib_path) -sys.path.append(".") -# from opensora.eval.cal_fvd import calculate_fvd -from opensora.eval.cal_lpips import calculate_lpips -from opensora.eval.cal_psnr import calculate_psnr - -try: - from opensora.eval.cal_flolpips import calculate_flolpips - - flolpips_isavailable = True -except Exception: - flolpips_isavailable = False -from opensora.eval.cal_ssim import calculate_ssim -from opensora.models.causalvideovae.model.dataset_videobase import create_dataloader -from opensora.utils.dataset_utils import create_video_transforms -from tqdm import tqdm - - -class VideoDataset: - def __init__( - self, - real_video_dir, - generated_video_dir, - num_frames, - sample_rate=1, - crop_size=None, - resolution=128, - output_columns=["real", "generated"], - ) -> None: - super().__init__() - self.real_video_files = self.combine_without_prefix(real_video_dir) - self.generated_video_files = self.combine_without_prefix(generated_video_dir) - assert ( - len(self.real_video_files) == len(self.generated_video_files) and len(self.real_video_files) > 0 - ), "Expect that the real and generated folders are not empty and contain the equal number of videos!" - self.num_frames = num_frames - self.sample_rate = sample_rate - self.crop_size = crop_size - self.short_size = resolution - self.output_columns = output_columns - - self.pixel_transforms = create_video_transforms( - size=self.short_size, - crop_size=crop_size, - random_crop=False, - disable_flip=True, - num_frames=num_frames, - backend="al", - ) - - def __len__(self): - return len(self.real_video_files) - - def __getitem__(self, index): - if index >= len(self): - raise IndexError - real_video_file = self.real_video_files[index] - generated_video_file = self.generated_video_files[index] - if os.path.basename(real_video_file).split(".")[0] != os.path.basename(generated_video_file).split(".")[0]: - print( - f"Warning! video file name mismatch! real and generated {os.path.basename(real_video_file)} and {os.path.basename(generated_video_file)}" - ) - real_video_tensor = self._load_video(real_video_file) - generated_video_tensor = self._load_video(generated_video_file) - return real_video_tensor.astype(np.float32), generated_video_tensor.astype(np.float32) - - def _load_video(self, video_path): - num_frames = self.num_frames - sample_rate = self.sample_rate - decord_vr = VideoReader( - video_path, - ) - total_frames = len(decord_vr) - sample_frames_len = sample_rate * num_frames - - if total_frames >= sample_frames_len: - s = 0 - e = s + sample_frames_len - num_frames = num_frames - else: - s = 0 - e = total_frames - num_frames = int(total_frames / sample_frames_len * num_frames) - print(f"Video total number of frames {total_frames} is less than the target num_frames {sample_frames_len}") - print(video_path) - - frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int) - pixel_values = decord_vr.get_batch(frame_id_list).asnumpy() - # video_data = video_data.transpose(0, 3, 1, 2) # (T, H, W, C) -> (C, T, H, W) - # NOTE:it's to ensure augment all frames in a video in the same way. - # ref: https://albumentations.ai/docs/examples/example_multi_target/ - - inputs = {"image": pixel_values[0]} - for i in range(num_frames - 1): - inputs[f"image{i}"] = pixel_values[i + 1] - - output = self.pixel_transforms(**inputs) - - pixel_values = np.stack(list(output.values()), axis=0) - # (t h w c) -> (t c h w) - pixel_values = np.transpose(pixel_values, (0, 3, 1, 2)) - pixel_values = pixel_values / 255.0 - return pixel_values - - def combine_without_prefix(self, folder_path, prefix="."): - folder = [] - assert os.path.exists(folder_path), f"Expect that {folder_path} exist!" - for name in os.listdir(folder_path): - if name[0] == prefix: - continue - if osp.isfile(osp.join(folder_path, name)): - folder.append(osp.join(folder_path, name)) - folder.sort() - return folder - - -def calculate_common_metric(args, dataloader, dataset_size): - score_list = [] - index = 0 - for batch_data in tqdm( - dataloader, total=dataset_size - ): # {'real': real_video_tensor, 'generated':generated_video_tensor } - real_videos = batch_data["real"] - generated_videos = batch_data["generated"] - assert real_videos.shape[2] == generated_videos.shape[2] - if args.metric == "fvd": - if index == 0: - print("calculate fvd...") - raise ValueError - # tmp_list = list(calculate_fvd(real_videos, generated_videos, method=args.fvd_method)["value"].values()) - elif args.metric == "ssim": - if index == 0: - print("calculate ssim...") - tmp_list = list(calculate_ssim(real_videos, generated_videos)["value"].values()) - elif args.metric == "psnr": - if index == 0: - print("calculate psnr...") - tmp_list = list(calculate_psnr(real_videos, generated_videos)["value"].values()) - elif args.metric == "flolpips": - if flolpips_isavailable: - result = calculate_flolpips( - real_videos, - generated_videos, - ) - tmp_list = list(result["value"].values()) - else: - continue - else: - if index == 0: - print("calculate_lpips...") - tmp_list = list( - calculate_lpips( - real_videos, - generated_videos, - )["value"].values() - ) - index += 1 - score_list += tmp_list - return np.mean(score_list) - - -def main(): - parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument("--batch_size", type=int, default=2, help="Batch size to use") - parser.add_argument("--real_video_dir", type=str, help=("the path of real videos`")) - parser.add_argument("--generated_video_dir", type=str, help=("the path of generated videos`")) - parser.add_argument("--device", type=str, default=None, help="Device to use. Like GPU or Ascend") - parser.add_argument( - "--num_workers", - type=int, - default=8, - help=("Number of processes to use for data loading. " "Defaults to `min(8, num_cpus)`"), - ) - parser.add_argument("--sample_fps", type=int, default=30) - parser.add_argument("--resolution", type=int, default=336) - parser.add_argument("--crop_size", type=int, default=None) - parser.add_argument("--num_frames", type=int, default=100) - parser.add_argument("--sample_rate", type=int, default=1) - parser.add_argument("--subset_size", type=int, default=None) - parser.add_argument("--metric", type=str, default="fvd", choices=["fvd", "psnr", "ssim", "lpips", "flolpips"]) - parser.add_argument("--fvd_method", type=str, default="styleganv", choices=["styleganv", "videogpt"]) - - args = parser.parse_args() - - if args.num_workers is None: - try: - num_cpus = len(os.sched_getaffinity(0)) - except AttributeError: - # os.sched_getaffinity is not available under Windows, use - # os.cpu_count instead (which may not return the *available* number - # of CPUs). - num_cpus = os.cpu_count() - - num_workers = min(num_cpus, 8) if num_cpus is not None else 0 - else: - num_workers = args.num_workers - - dataset = VideoDataset( - args.real_video_dir, - args.generated_video_dir, - num_frames=args.num_frames, - sample_rate=args.sample_rate, - crop_size=args.crop_size, - resolution=args.resolution, - ) - - dataloader = create_dataloader( - dataset, - batch_size=args.batch_size, - ds_name="video", - num_parallel_workers=num_workers, - shuffle=False, - drop_remainder=False, - ) - dataset_size = math.ceil(len(dataset) / float(args.batch_size)) - dataloader = dataloader.create_dict_iterator(1, output_numpy=True) - metric_score = calculate_common_metric(args, dataloader, dataset_size) - print("metric: ", args.metric, " ", metric_score) - - -if __name__ == "__main__": - main() diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/causal_vae/__init__.py b/examples/opensora_pku/opensora/models/causalvideovae/model/causal_vae/__init__.py deleted file mode 100644 index 1cedd72fc0..0000000000 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/causal_vae/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -import logging - -from mindspore import nn - -from .modeling_causalvae import CausalVAEModel - -logger = logging.getLogger(__name__) - - -class CausalVAEModelWrapper(nn.Cell): - def __init__(self, model_path, subfolder=None, cache_dir=None, use_ema=False, **kwargs): - super(CausalVAEModelWrapper, self).__init__() - # if os.path.exists(ckpt): - # self.vae = CausalVAEModel.load_from_checkpoint(ckpt) - self.vae, loading_info = CausalVAEModel.from_pretrained( - model_path, subfolder=subfolder, cache_dir=cache_dir, output_loading_info=True, **kwargs - ) - logger.info(loading_info) - if use_ema: - self.vae.init_from_ema(model_path) - self.vae = self.vae.ema - - def encode(self, x): # b c t h w - # x = self.vae.encode(x) - x = self.vae.encode(x) * 0.18215 - return x - - def decode(self, x): - # x = self.vae.decode(x) - x = self.vae.decode(x / 0.18215) - # b c t h w -> b t c h w - x = x.permute(0, 2, 1, 3, 4) - return x - - def dtype(self): - return self.vae.dtype diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/causal_vae/modeling_causalvae.py b/examples/opensora_pku/opensora/models/causalvideovae/model/causal_vae/modeling_causalvae.py deleted file mode 100644 index 5eb6e0ce09..0000000000 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/causal_vae/modeling_causalvae.py +++ /dev/null @@ -1,963 +0,0 @@ -import logging -import os -from typing import Tuple - -from opensora.acceleration.parallel_states import get_sequence_parallel_state - -import mindspore as ms -from mindspore import nn, ops - -from mindone.diffusers import __version__ -from mindone.diffusers.models.modeling_utils import load_state_dict -from mindone.diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, _add_variant, _get_model_file - -from ..modeling_videobase import VideoBaseAE -from ..modules.conv import CausalConv3d -from ..modules.ops import nonlinearity -from ..utils.model_utils import resolve_str_to_obj - -logger = logging.getLogger(__name__) - - -class CausalVAEModel(VideoBaseAE): - """ - The default vales are set to be the same as those used in OpenSora v1.1 - """ - - def __init__( - self, - lr: float = 1e-5, # ignore - hidden_size: int = 128, - z_channels: int = 4, - hidden_size_mult: Tuple[int] = (1, 2, 4, 4), - attn_resolutions: Tuple[int] = [], - dropout: float = 0.0, - resolution: int = 256, - double_z: bool = True, - embed_dim: int = 4, - num_res_blocks: int = 2, - q_conv: str = "CausalConv3d", - encoder_conv_in: str = "Conv2d", - encoder_conv_out: str = "CausalConv3d", - encoder_attention: str = "AttnBlock3DFix", - encoder_resnet_blocks: Tuple[str] = ( - "ResnetBlock2D", - "ResnetBlock2D", - "ResnetBlock3D", - "ResnetBlock3D", - ), - encoder_spatial_downsample: Tuple[str] = ( - "Downsample", - "Downsample", - "Downsample", - "", - ), - encoder_temporal_downsample: Tuple[str] = ( - "", - "TimeDownsampleRes2x", - "TimeDownsampleRes2x", - "", - ), - encoder_mid_resnet: str = "ResnetBlock3D", - decoder_conv_in: str = "CausalConv3d", - decoder_conv_out: str = "CausalConv3d", - decoder_attention: str = "AttnBlock3DFix", - decoder_resnet_blocks: Tuple[str] = ( - "ResnetBlock3D", - "ResnetBlock3D", - "ResnetBlock3D", - "ResnetBlock3D", - ), - decoder_spatial_upsample: Tuple[str] = ( - "", - "SpatialUpsample2x", - "SpatialUpsample2x", - "SpatialUpsample2x", - ), - decoder_temporal_upsample: Tuple[str] = ("", "", "TimeUpsampleRes2x", "TimeUpsampleRes2x"), - decoder_mid_resnet: str = "ResnetBlock3D", - use_quant_layer: bool = True, - ckpt_path=None, - ignore_keys=[], - monitor=None, - use_fp16=False, - upcast_sigmoid=False, - use_recompute=False, - ): - super().__init__() - dtype = ms.float16 if use_fp16 else ms.float32 - - self.encoder = Encoder( - z_channels=z_channels, - hidden_size=hidden_size, - hidden_size_mult=hidden_size_mult, - attn_resolutions=attn_resolutions, - conv_in=encoder_conv_in, - conv_out=encoder_conv_out, - attention=encoder_attention, - resnet_blocks=encoder_resnet_blocks, - spatial_downsample=encoder_spatial_downsample, - temporal_downsample=encoder_temporal_downsample, - mid_resnet=encoder_mid_resnet, - dropout=dropout, - resolution=resolution, - num_res_blocks=num_res_blocks, - double_z=double_z, - dtype=dtype, - upcast_sigmoid=upcast_sigmoid, - ) - - self.decoder = Decoder( - z_channels=z_channels, - hidden_size=hidden_size, - hidden_size_mult=hidden_size_mult, - attn_resolutions=attn_resolutions, - conv_in=decoder_conv_in, - conv_out=decoder_conv_out, - attention=decoder_attention, - resnet_blocks=decoder_resnet_blocks, - spatial_upsample=decoder_spatial_upsample, - temporal_upsample=decoder_temporal_upsample, - mid_resnet=decoder_mid_resnet, - dropout=dropout, - resolution=resolution, - num_res_blocks=num_res_blocks, - dtype=dtype, - upcast_sigmoid=upcast_sigmoid, - ) - self.embed_dim = embed_dim - - if monitor is not None: - self.monitor = monitor - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - - self.split = ops.Split(axis=1, output_num=2) - self.concat = ops.Concat(axis=1) - self.exp = ops.Exp() - self.stdnormal = ops.StandardNormal() - self.depend = ops.Depend() if get_sequence_parallel_state() else None - - # self.encoder.recompute() - # self.decoder.recompute() - self.tile_sample_min_size = 256 - self.tile_sample_min_size_t = 33 - self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1))) - # t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0] - # self.tile_latent_min_size_t = int((self.tile_sample_min_size_t - 1) / (2 ** len(t_down_ratio))) + 1 - self.tile_latent_min_size_t = 16 - self.tile_overlap_factor = 0.125 - self.use_tiling = False - self.use_quant_layer = use_quant_layer - if self.use_quant_layer: - quant_conv_cls = resolve_str_to_obj(q_conv) - self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1) - self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1) - if use_recompute: - self.recompute(self.encoder) - self.recompute(self.decoder) - if self.use_quant_layer: - self.recompute(self.quant_conv) - self.recompute(self.post_quant_conv) - - def recompute(self, b): - if not b._has_config_recompute: - b.recompute(parallel_optimizer_comm_recompute=True) - if isinstance(b, nn.CellList): - self.recompute(b[-1]) - elif ms.get_context("mode") == ms.GRAPH_MODE: - b.add_flags(output_no_recompute=True) - - def get_encoder(self): - if self.use_quant_layer: - return [self.quant_conv, self.encoder] - return [self.encoder] - - def get_decoder(self): - if self.use_quant_layer: - return [self.post_quant_conv, self.decoder] - return [self.decoder] - - # rewrite class method to allow the state dict as input - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - state_dict = kwargs.pop("state_dict", None) # additional key argument - cache_dir = kwargs.pop("cache_dir", None) - ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) - force_download = kwargs.pop("force_download", False) - from_flax = kwargs.pop("from_flax", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - output_loading_info = kwargs.pop("output_loading_info", False) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - mindspore_dtype = kwargs.pop("mindspore_dtype", None) - subfolder = kwargs.pop("subfolder", None) - variant = kwargs.pop("variant", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - # Load config if we don't provide a configuration - config_path = pretrained_model_name_or_path - - user_agent = { - "diffusers": __version__, - "file_type": "model", - "framework": "pytorch", - } - - # load config - config, unused_kwargs, commit_hash = cls.load_config( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - return_commit_hash=True, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - **kwargs, - ) - - # load model - model_file = None - if from_flax: - raise NotImplementedError("loading flax checkpoint in mindspore model is not yet supported.") - else: - if state_dict is None: # edits: only search for model_file if state_dict is not provided - if use_safetensors: - try: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - commit_hash=commit_hash, - ) - except IOError as e: - if not allow_pickle: - raise e - pass - if model_file is None: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=_add_variant(WEIGHTS_NAME, variant), - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - commit_hash=commit_hash, - ) - - model = cls.from_config(config, **unused_kwargs) - if state_dict is None: # edits: only load model_file if state_dict is None - state_dict = load_state_dict(model_file, variant=variant) - model._convert_deprecated_attention_blocks(state_dict) - - model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( - model, - state_dict, - model_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=ignore_mismatched_sizes, - ) - - loading_info = { - "missing_keys": missing_keys, - "unexpected_keys": unexpected_keys, - "mismatched_keys": mismatched_keys, - "error_msgs": error_msgs, - } - - if mindspore_dtype is not None and not isinstance(mindspore_dtype, ms.Type): - raise ValueError( - f"{mindspore_dtype} needs to be of type `ms.Type`, e.g. `ms.float16`, but is {type(mindspore_dtype)}." - ) - elif mindspore_dtype is not None: - model = model.to(mindspore_dtype) - - model.register_to_config(_name_or_path=pretrained_model_name_or_path) - - # Set model in evaluation mode to deactivate DropOut modules by default - model.set_train(False) - if output_loading_info: - return model, loading_info - - return model - - def init_from_vae2d(self, path): - # default: tail init - # path: path to vae 2d model ckpt - vae2d_sd = ms.load_checkpoint(path) - vae_2d_keys = list(vae2d_sd.keys()) - vae_3d_keys = list(self.parameters_dict().keys()) - - # 3d -> 2d - map_dict = { - "conv.weight": "weight", - "conv.bias": "bias", - } - - new_state_dict = {} - for key_3d in vae_3d_keys: - if key_3d.startswith("loss"): - continue - - # param name mapping from vae-3d to vae-2d - key_2d = key_3d - for kw in map_dict: - key_2d = key_2d.replace(kw, map_dict[kw]) - - assert key_2d in vae_2d_keys, f"Key {key_2d} ({key_3d}) not found in 2D VAE" - - # set vae 3d state dict - shape_3d = self.parameters_dict()[key_3d].shape - shape_2d = vae2d_sd[key_2d].shape - if "bias" in key_2d: - assert shape_3d == shape_2d, f"Shape mismatch for key {key_3d} ({key_2d})" - new_state_dict[key_3d] = vae2d_sd[key_2d] - elif "norm" in key_2d: - assert shape_3d == shape_2d, f"Shape mismatch for key {key_3d} ({key_2d})" - new_state_dict[key_3d] = vae2d_sd[key_2d] - elif "conv" in key_2d or "nin_shortcut" in key_2d: - if shape_3d[:2] != shape_2d[:2]: - logger.info(key_2d, shape_3d, shape_2d) - w = vae2d_sd[key_2d] - new_w = ms.ops.zeros(shape_3d, dtype=w.dtype) - # tail initialization - new_w[:, :, -1, :, :] = w # cin, cout, t, h, w - - new_w = ms.Parameter(new_w, name=key_3d) - - new_state_dict[key_3d] = new_w - elif "attn_1" in key_2d: - new_val = vae2d_sd[key_2d].expand_dims(axis=2) - new_param = ms.Parameter(new_val, name=key_3d) - new_state_dict[key_3d] = new_param - else: - raise NotImplementedError(f"Key {key_3d} ({key_2d}) not implemented") - - m, u = ms.load_param_into_net(self, new_state_dict) - if len(m) > 0: - logger.info("net param not loaded: ", m) - if len(u) > 0: - logger.info("checkpoint param not loaded: ", u) - - def init_from_ckpt(self, path, ignore_keys=list()): - # TODO: support auto download pretrained checkpoints - sd = ms.load_checkpoint(path) - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - logger.info("Deleting key {} from state_dict.".format(k)) - del sd[k] - - if "ema_state_dict" in sd and len(sd["ema_state_dict"]) > 0 and os.environ.get("NOT_USE_EMA_MODEL", 0) == 0: - logger.info("Load from ema model!") - sd = sd["ema_state_dict"] - sd = {key.replace("module.", ""): value for key, value in sd.items()} - elif "state_dict" in sd: - logger.info("Load from normal model!") - if "gen_model" in sd["state_dict"]: - sd = sd["state_dict"]["gen_model"] - else: - sd = sd["state_dict"] - - ms.load_param_into_net(self, sd, strict_load=False) - logger.info(f"Restored from {path}") - - def _encode(self, x): - # return latent distribution, N(mean, logvar) - h = self.encoder(x) - if self.use_quant_layer: - h = self.quant_conv(h) - mean, logvar = self.split(h) - - return mean, logvar - - def sample(self, mean, logvar): - # sample z from latent distribution - logvar = ops.clip_by_value(logvar, -30.0, 20.0) - std = self.exp(0.5 * logvar) - z = mean + std * self.stdnormal(mean.shape) - - return z - - def encode(self, x): - if self.use_tiling and ( - x.shape[-1] > self.tile_sample_min_size - or x.shape[-2] > self.tile_sample_min_size - or x.shape[-3] > self.tile_sample_min_size_t - ): - posterior_mean, posterior_logvar = self.tiled_encode(x) - else: - # embedding, get latent representation z - posterior_mean, posterior_logvar = self._encode(x) - z = self.sample(posterior_mean, posterior_logvar) - - return z - - def tiled_encode2d(self, x): - overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) - row_limit = self.tile_latent_min_size - blend_extent - - # Split the image into 512x512 tiles and encode them separately. - rows = () - tile = None - for i in range(0, x.shape[3], overlap_size): - row = () - if self.depend is not None: - x = self.depend(x, tile) - for j in range(0, x.shape[4], overlap_size): - if self.depend is not None: - x = self.depend(x, tile) - tile = x[ - :, - :, - :, - i : i + self.tile_sample_min_size, - j : j + self.tile_sample_min_size, - ] - tile = self.encoder(tile) - if self.use_quant_layer: - tile = self.quant_conv(tile) - row += (tile,) - rows += (row,) - - result_rows = () - for i, row in enumerate(rows): - result_row = () - for j, tile in enumerate(row): - # blend the above tile and the left tile - # to the current tile and add the current tile to the result row - if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_extent) - if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_extent) - result_row += (tile[:, :, :, :row_limit, :row_limit],) - result_rows += (ops.cat(result_row, axis=4),) - - moments = ops.cat(result_rows, axis=3) - return moments - - def tiled_encode(self, x): - t = x.shape[2] - t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t - 1)] - if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0: - t_chunk_start_end = [[0, t]] - else: - t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)] - if t_chunk_start_end[-1][-1] > t: - t_chunk_start_end[-1][-1] = t - elif t_chunk_start_end[-1][-1] < t: - last_start_end = [t_chunk_idx[-1], t] - t_chunk_start_end.append(last_start_end) - moments = [] - for idx, (start, end) in enumerate(t_chunk_start_end): - chunk_x = x[:, :, start:end] - if idx != 0: - moment = self.tiled_encode2d(chunk_x)[:, :, 1:] - else: - moment = self.tiled_encode2d(chunk_x) - moments.append(moment) - moments = ops.cat(moments, axis=2) - mean, logvar = self.split(moments) - return mean, logvar - - def decode(self, z): - if self.use_tiling and ( - z.shape[-1] > self.tile_latent_min_size - or z.shape[-2] > self.tile_latent_min_size - or z.shape[-3] > self.tile_latent_min_size_t - ): - return self.tiled_decode(z) - if self.use_quant_layer: - z = self.post_quant_conv(z) - dec = self.decoder(z) - return dec - - def tiled_decode2d(self, z): - overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) - row_limit = self.tile_sample_min_size - blend_extent - - # Split z into overlapping 64x64 tiles and decode them separately. - # The tiles have an overlap to avoid seams between tiles. - rows = [] - for i in range(0, z.shape[3], overlap_size): - row = [] - for j in range(0, z.shape[4], overlap_size): - tile = z[ - :, - :, - :, - i : i + self.tile_latent_min_size, - j : j + self.tile_latent_min_size, - ] - if self.use_quant_layer: - tile = self.post_quant_conv(tile) - decoded = self.decoder(tile) - row.append(decoded) - rows.append(row) - result_rows = [] - for i, row in enumerate(rows): - result_row = [] - for j, tile in enumerate(row): - # blend the above tile and the left tile - # to the current tile and add the current tile to the result row - if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_extent) - if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_extent) - result_row.append(tile[:, :, :, :row_limit, :row_limit]) - result_rows.append(ops.cat(result_row, axis=4)) - - dec = ops.cat(result_rows, axis=3) - return dec - - def tiled_decode(self, x): - t = x.shape[2] - t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t - 1)] - if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0: - t_chunk_start_end = [[0, t]] - else: - t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)] - if t_chunk_start_end[-1][-1] > t: - t_chunk_start_end[-1][-1] = t - elif t_chunk_start_end[-1][-1] < t: - last_start_end = [t_chunk_idx[-1], t] - t_chunk_start_end.append(last_start_end) - dec_ = [] - for idx, (start, end) in enumerate(t_chunk_start_end): - chunk_x = x[:, :, start:end] - if idx != 0: - dec = self.tiled_decode2d(chunk_x)[:, :, 1:] - else: - dec = self.tiled_decode2d(chunk_x) - dec_.append(dec) - dec_ = ops.cat(dec_, axis=2) - return dec_ - - def construct(self, input): - # overall pass, mostly for training - posterior_mean, posterior_logvar = self._encode(input) - z = self.sample(posterior_mean, posterior_logvar) - - recons = self.decode(z) - - return recons, posterior_mean, posterior_logvar - - def enable_tiling(self, use_tiling: bool = True): - self.use_tiling = use_tiling - - def disable_tiling(self): - self.enable_tiling(False) - - def blend_v(self, a: ms.Tensor, b: ms.Tensor, blend_extent: int) -> ms.Tensor: - blend_extent = min(a.shape[3], b.shape[3], blend_extent) - for y in range(blend_extent): - b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( - y / blend_extent - ) - return b - - def blend_h(self, a: ms.Tensor, b: ms.Tensor, blend_extent: int) -> ms.Tensor: - blend_extent = min(a.shape[4], b.shape[4], blend_extent) - for x in range(blend_extent): - b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( - x / blend_extent - ) - return b - - def validation_step(self, batch_idx): - raise NotImplementedError - - -class Encoder(nn.Cell): - """ - default value aligned to v1.1 vae config.json - """ - - def __init__( - self, - z_channels: int = 4, - hidden_size: int = 128, - hidden_size_mult: Tuple[int] = (1, 2, 4, 4), - attn_resolutions: Tuple[int] = (), - conv_in: str = "Conv2d", - conv_out: str = "CausalConv3d", - attention: str = "AttnBlock3D", # already fixed, same as AttnBlock3DFix - resnet_blocks: Tuple[str] = ( - "ResnetBlock2D", - "ResnetBlock2D", - "ResnetBlock3D", - "ResnetBlock3D", - ), - spatial_downsample: Tuple[str] = ( - "Downsample", - "Downsample", - "Downsample", - "", - ), - temporal_downsample: Tuple[str] = ( - "", - "TimeDownsampleRes2x", - "TimeDownsampleRes2x", - "", - ), - mid_resnet: str = "ResnetBlock3D", - dropout: float = 0.0, - resolution: int = 256, - num_res_blocks: int = 2, - double_z: bool = True, - upcast_sigmoid=False, - dtype=ms.float32, - **ignore_kwargs, - ): - """ - ch: hidden size, i.e. output channels of the first conv layer. typical: 128 - out_ch: placeholder, not used in Encoder - hidden_size_mult: channel multiply factors for each res block, also determine the number of res blocks. - Each block will be applied with spatial downsample x2 except for the last block. - In total, the spatial downsample rate = 2**(len(hidden_size_mult)-1) - resolution: spatial resolution, 256 - time_compress: the begging `time_compress` blocks will be applied with temporal downsample x2. - In total, the temporal downsample rate = 2**time_compress - """ - super().__init__() - assert len(resnet_blocks) == len(hidden_size_mult), print(hidden_size_mult, resnet_blocks) - self.num_resolutions = len(hidden_size_mult) - self.resolution = resolution - self.num_res_blocks = num_res_blocks - - self.dtype = dtype - self.upcast_sigmoid = (upcast_sigmoid,) - - # 1. Input conv - self.conv_in_name = conv_in - if conv_in == "Conv2d": - self.conv_in = nn.Conv2d(3, hidden_size, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True) - elif conv_in == "CausalConv3d": - self.conv_in = CausalConv3d( - 3, - hidden_size, - kernel_size=3, - stride=1, - padding=1, - ) - else: - raise NotImplementedError - - # 2. Downsample - curr_res = resolution - in_ch_mult = (1,) + tuple(hidden_size_mult) - self.in_ch_mult = in_ch_mult - self.down = nn.CellList(auto_prefix=False) - self.downsample_flag = [0] * self.num_resolutions - self.time_downsample_flag = [0] * self.num_resolutions - for i_level in range(self.num_resolutions): - block = nn.CellList() - attn = nn.CellList() - block_in = hidden_size * in_ch_mult[i_level] # input channels - block_out = hidden_size * hidden_size_mult[i_level] # output channels - for i_block in range(self.num_res_blocks): - block.append( - resolve_str_to_obj(resnet_blocks[i_level])( - in_channels=block_in, - out_channels=block_out, - dropout=dropout, - dtype=self.dtype, - upcast_sigmoid=upcast_sigmoid, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(resolve_str_to_obj(attention)(block_in, dtype=self.dtype)) - - down = nn.Cell() - down.block = block - down.attn = attn - - # do spatial downsample according to config - if spatial_downsample[i_level]: - down.downsample = resolve_str_to_obj(spatial_downsample[i_level])(block_in, block_in, dtype=self.dtype) - curr_res = curr_res // 2 - self.downsample_flag[i_level] = 1 - else: - # TODO: still need it for 910b in new MS version? - down.downsample = nn.Identity() - - # do temporal downsample according to config - if temporal_downsample[i_level]: - # TODO: add dtype support? - down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])(block_in, block_in) - self.time_downsample_flag[i_level] = 1 - else: - # TODO: still need it for 910b in new MS version? - down.time_downsample = nn.Identity() - - down.update_parameters_name(prefix=self.param_prefix + f"down.{i_level}.") - self.down.append(down) - - # middle - self.mid = nn.Cell() - self.mid.block_1 = resolve_str_to_obj(mid_resnet)( - in_channels=block_in, - out_channels=block_in, - dropout=dropout, - dtype=self.dtype, - upcast_sigmoid=upcast_sigmoid, - ) - self.mid.attn_1 = resolve_str_to_obj(attention)(block_in, dtype=self.dtype) - self.mid.block_2 = resolve_str_to_obj(mid_resnet)( - in_channels=block_in, - out_channels=block_in, - dropout=dropout, - dtype=self.dtype, - upcast_sigmoid=upcast_sigmoid, - ) - self.mid.update_parameters_name(prefix=self.param_prefix + "mid.") - - # end - self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) - # self.norm_out = Normalize(block_in, extend=True) - - assert conv_out == "CausalConv3d", "Only CausalConv3d is supported for conv_out" - self.conv_out = resolve_str_to_obj(conv_out)( - block_in, - 2 * z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1, - ) - - # copied from models.causalvideovae.model.modules.conv - def rearrange_in(self, x): - # b c f h w -> b f c h w - B, C, F, H, W = x.shape - x = ops.transpose(x, (0, 2, 1, 3, 4)) - # -> (b*f c h w) - x = ops.reshape(x, (-1, C, H, W)) - - return x - - # copied from models.causalvideovae.model.modules.conv - def rearrange_out(self, x, F): - BF, D, H_, W_ = x.shape - # (b*f D h w) -> (b f D h w) - x = ops.reshape(x, (BF // F, F, D, H_, W_)) - # -> (b D f h w) - x = ops.transpose(x, (0, 2, 1, 3, 4)) - - return x - - def construct(self, x): - # downsampling - if self.conv_in_name != "Conv2d": - hs = self.conv_in(x) - else: - F = x.shape[-3] - x = self.rearrange_in(x) - x = self.conv_in(x) - hs = self.rearrange_out(x, F) - - h = hs - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - # import pdb; pdb.set_trace() - h = self.down[i_level].block[i_block](hs) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs = h - # if hasattr(self.down[i_level], "downsample"): - # if not isinstance(self.down[i_level].downsample, nn.Identity): - if self.downsample_flag[i_level]: - hs = self.down[i_level].downsample(hs) - # if hasattr(self.down[i_level], "time_downsample"): - # if not isinstance(self.down[i_level].time_downsample, nn.Identity): - if self.time_downsample_flag[i_level]: - hs_down = self.down[i_level].time_downsample(hs) - hs = hs_down - - # middle - # h = hs[-1] - h = self.mid.block_1(h) - h = self.mid.attn_1(h) - h = self.mid.block_2(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h, upcast=self.upcast_sigmoid) - h = self.conv_out(h) - return h - - -class Decoder(nn.Cell): - """ - default value aligned to v1.1 vae config.json - """ - - def __init__( - self, - z_channels: int = 4, - hidden_size: int = 128, - hidden_size_mult: Tuple[int] = (1, 2, 4, 4), - attn_resolutions: Tuple[int] = (), - conv_in: str = "CausalConv3d", - conv_out: str = "CausalConv3d", - attention: str = "AttnBlock3D", # already fixed, same as AttnBlock3DFix - resnet_blocks: Tuple[str] = ( - "ResnetBlock3D", - "ResnetBlock3D", - "ResnetBlock3D", - "ResnetBlock3D", - ), - spatial_upsample: Tuple[str] = ("", "SpatialUpsample2x", "SpatialUpsample2x", "SpatialUpsample2x"), - temporal_upsample: Tuple[str] = ("", "", "TimeUpsampleRes2x", "TimeUpsampleRes2x"), - mid_resnet: str = "ResnetBlock3D", - dropout: float = 0.0, - resolution: int = 256, - num_res_blocks: int = 2, - double_z: bool = True, - upcast_sigmoid=False, - dtype=ms.float32, - **ignore_kwargs, - ): - super().__init__() - - self.num_resolutions = len(hidden_size_mult) - self.resolution = resolution - self.num_res_blocks = num_res_blocks - - self.dtype = dtype - self.upcast_sigmoid = upcast_sigmoid - - # 1. decode input z conv - # compute in_ch_mult, block_in and curr_res at lowest res - block_in = hidden_size * hidden_size_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - # self.z_shape = (1, z_channels, curr_res, curr_res) - # logger.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) - - # z to block_in - assert conv_in == "CausalConv3d", "Only CausalConv3d is supported for conv_in in Decoder currently" - self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, padding=1) - - # 2. middle - self.mid = nn.Cell() - self.mid.block_1 = resolve_str_to_obj(mid_resnet)( - in_channels=block_in, out_channels=block_in, dropout=dropout, dtype=self.dtype - ) - self.mid.attn_1 = resolve_str_to_obj(attention)(block_in, dtype=self.dtype) - self.mid.block_2 = resolve_str_to_obj(mid_resnet)( - in_channels=block_in, out_channels=block_in, dropout=dropout, dtype=self.dtype - ) - self.mid.update_parameters_name(prefix=self.param_prefix + "mid.") - - # 3. upsampling - self.up = nn.CellList(auto_prefix=False) - self.upsample_flag = [0] * self.num_resolutions - self.time_upsample_flag = [0] * self.num_resolutions - # i_level: 3 -> 2 -> 1 -> 0 - for i_level in reversed(range(self.num_resolutions)): - block = nn.CellList() - attn = nn.CellList() - block_out = hidden_size * hidden_size_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - resolve_str_to_obj(resnet_blocks[i_level])( - in_channels=block_in, - out_channels=block_out, - dropout=dropout, - dtype=self.dtype, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(resolve_str_to_obj(attention)(block_in, dtype=self.dtype)) - up = nn.Cell() - up.block = block - up.attn = attn - # do spatial upsample x2 except for the first block - if spatial_upsample[i_level]: - up.upsample = resolve_str_to_obj(spatial_upsample[i_level])(block_in, block_in, dtype=self.dtype) - curr_res = curr_res * 2 - self.upsample_flag[i_level] = 1 - else: - up.upsample = nn.Identity() - # do temporal upsample x2 in the bottom tc blocks - if temporal_upsample[i_level]: - # TODO: support dtype? - up.time_upsample = resolve_str_to_obj(temporal_upsample[i_level])(block_in, block_in) - self.time_upsample_flag[i_level] = 1 - else: - up.time_upsample = nn.Identity() - - up.update_parameters_name(prefix=self.param_prefix + f"up.{i_level}.") - if len(self.up) != 0: - self.up.insert(0, up) - else: - self.up.append(up) - - # end - self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) - # self.norm_out = Normalize(block_in, extend=True) - - assert conv_out == "CausalConv3d", "Only CausalConv3d is supported for conv_out in Decoder currently" - self.conv_out = CausalConv3d(block_in, 3, kernel_size=3, padding=1) - - def construct(self, z): - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h) - h = self.mid.attn_1(h) - h = self.mid.block_2(h) - - # upsampling - i_level = self.num_resolutions - while i_level > 0: - i_level -= 1 - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - # if hasattr(self.up[i_level], 'upsample'): - # if not isinstance(self.up[i_level].upsample, nn.Identity): - if self.upsample_flag[i_level]: - h = self.up[i_level].upsample(h) - - # if hasattr(self.up[i_level], 'time_upsample'): - # if not isinstance(self.up[i_level].time_upsample, nn.Identity): - if self.time_upsample_flag[i_level]: - h = self.up[i_level].time_upsample(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h, upcast=self.upcast_sigmoid) - h = self.conv_out(h) - return h diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py index 3c0d67d400..910a7235b3 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py @@ -122,10 +122,11 @@ def loss_function( else: mean_weighted_nll_loss = nll_loss.sum() / nll_loss.shape[0] # mean_nll_loss = mean_weighted_nll_loss + nll_loss = nll_loss.sum() / nll_loss.shape[0] # 2.3 kl loss kl_loss = self.kl(mean, logvar) - kl_loss = kl_loss.sum() / bs + kl_loss = kl_loss.sum() / kl_loss.shape[0] if wavelet_coeffs: wl_loss_l2 = mint.sum(l1(wavelet_coeffs[0], wavelet_coeffs[1])) / bs wl_loss_l3 = mint.sum(l1(wavelet_coeffs[2], wavelet_coeffs[3])) / bs From 6ce33e0ecc0f32147093abb7ce6dcbfe9fd2035f Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 5 Nov 2024 15:38:11 +0800 Subject: [PATCH 025/133] update vae training --- .../model/vae/modeling_wfvae.py | 58 ------------------- .../scripts/causalvae/train_with_gan_loss.sh | 3 +- .../train_with_gan_loss_multi_device.sh | 11 ++-- 3 files changed, 6 insertions(+), 66 deletions(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py index 32a1139bd9..7db1bc6705 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py @@ -716,64 +716,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): return model - def init_from_vae2d(self, path): - # default: tail init - # path: path to vae 2d model ckpt - vae2d_sd = ms.load_checkpoint(path) - vae_2d_keys = list(vae2d_sd.keys()) - vae_3d_keys = list(self.parameters_dict().keys()) - - # 3d -> 2d - map_dict = { - "conv.weight": "weight", - "conv.bias": "bias", - } - - new_state_dict = {} - for key_3d in vae_3d_keys: - if key_3d.startswith("loss"): - continue - - # param name mapping from vae-3d to vae-2d - key_2d = key_3d - for kw in map_dict: - key_2d = key_2d.replace(kw, map_dict[kw]) - - assert key_2d in vae_2d_keys, f"Key {key_2d} ({key_3d}) should be in 2D VAE" - - # set vae 3d state dict - shape_3d = self.parameters_dict()[key_3d].shape - shape_2d = vae2d_sd[key_2d].shape - if "bias" in key_2d: - assert shape_3d == shape_2d, f"Shape mismatch for key {key_3d} ({key_2d})" - new_state_dict[key_3d] = vae2d_sd[key_2d] - elif "norm" in key_2d: - assert shape_3d == shape_2d, f"Shape mismatch for key {key_3d} ({key_2d})" - new_state_dict[key_3d] = vae2d_sd[key_2d] - elif "conv" in key_2d or "nin_shortcut" in key_2d: - if shape_3d[:2] != shape_2d[:2]: - logger.info(key_2d, shape_3d, shape_2d) - w = vae2d_sd[key_2d] - new_w = mint.zeros(shape_3d, dtype=w.dtype) - # tail initialization - new_w[:, :, -1, :, :] = w # cin, cout, t, h, w - - new_w = ms.Parameter(new_w, name=key_3d) - - new_state_dict[key_3d] = new_w - elif "attn_1" in key_2d: - new_val = vae2d_sd[key_2d].expand_dims(axis=2) - new_param = ms.Parameter(new_val, name=key_3d) - new_state_dict[key_3d] = new_param - else: - raise NotImplementedError(f"Key {key_3d} ({key_2d}) not implemented") - - m, u = ms.load_param_into_net(self, new_state_dict) - if len(m) > 0: - logger.info("net param not loaded: ", m) - if len(u) > 0: - logger.info("checkpoint param not loaded: ", u) - def init_from_ckpt(self, path, ignore_keys=list()): # TODO: support auto download pretrained checkpoints sd = ms.load_checkpoint(path) diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh index 764e12b1f3..1beca45127 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh @@ -12,7 +12,6 @@ python opensora/train/train_causalvae.py \ --video_num_frames 25 \ --resolution 256 \ --dataloader_num_workers 8 \ - --load_from_checkpoint pretrained/causal_vae_488_init.ckpt \ --start_learning_rate 1e-5 \ --lr_scheduler constant \ --optim adamw \ @@ -20,7 +19,7 @@ python opensora/train/train_causalvae.py \ --clip_grad True \ --weight_decay 0.0 \ --mode 1 \ - --init_loss_scale 65536 \ + --init_loss_scale 1 \ --jit_level "O0" \ --use_discriminator True \ --use_ema True\ diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh index 9adfcba876..259c2316d0 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh @@ -16,9 +16,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --data_file_path datasets/ucf101_train.csv \ --video_num_frames 25 \ --resolution 256 \ - --sample_rate 2 \ --dataloader_num_workers 8 \ - --load_from_checkpoint pretrained/causal_vae_488_init.ckpt \ --start_learning_rate 1e-5 \ --lr_scheduler constant \ --optim adamw \ @@ -26,15 +24,16 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --clip_grad True \ --weight_decay 0.0 \ --mode 1 \ - --init_loss_scale 65536 \ + --init_loss_scale 1 \ --jit_level "O0" \ --use_discriminator True \ --use_parallel True \ --use_ema True\ - --ema_start_step 0 \ --ema_decay 0.999 \ --perceptual_weight 1.0 \ --loss_type l1 \ + --sample_rate 1 \ --disc_cls causalvideovae.model.losses.LPIPSWithDiscriminator3D \ - --disc_start 2000 \ - --use_recompute True \ + --disc_start 0 \ + --wavelet_loss \ + --wavelet_weight 0.1 From 2bf6cdb4d293974dd06761eca0ba1fe5c35221de Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 5 Nov 2024 16:06:14 +0800 Subject: [PATCH 026/133] support recompute --- .../causalvideovae/model/vae/modeling_wfvae.py | 15 +++++++++++++++ .../opensora/train/train_causalvae.py | 5 +++-- .../causalvae/train_with_gan_loss_multi_device.sh | 2 ++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py index 7db1bc6705..e323c10238 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py @@ -378,6 +378,7 @@ def __init__( l2_upsample_block: str = "Spatial2xTime2x3DUpsample", l2_upsample_wavelet: str = "InverseHaarWaveletTransform3D", dtype=ms.float32, + use_recompute=False, ) -> None: super().__init__() self.use_tiling = False @@ -429,6 +430,20 @@ def __init__( self.stdnormal = mint.normal self.update_parameters_name() # update parameter names to solve pname mismatch + if use_recompute: + self.recompute(self.encoder) + self.recompute(self.decoder) + if self.use_quant_layer: + self.recompute(self.quant_conv) + self.recompute(self.post_quant_conv) + + def recompute(self, b): + if not b._has_config_recompute: + b.recompute(parallel_optimizer_comm_recompute=True) + if isinstance(b, nn.CellList): + self.recompute(b[-1]) + elif ms.get_context("mode") == ms.GRAPH_MODE: + b.add_flags(output_no_recompute=True) def get_encoder(self): if self.use_quant_layer: diff --git a/examples/opensora_pku/opensora/train/train_causalvae.py b/examples/opensora_pku/opensora/train/train_causalvae.py index d967baf61c..4f2c1c1e7b 100644 --- a/examples/opensora_pku/opensora/train/train_causalvae.py +++ b/examples/opensora_pku/opensora/train/train_causalvae.py @@ -61,11 +61,12 @@ def main(args): low_cpu_mem_usage=False, device_map=None, dtype=dtype, + use_recompute=args.use_recompute, ) else: if rank_id == 0: logger.warning(f"Model will be initialized from config file {args.model_config}.") - ae = model_cls.from_config(args.model_config, dtype=dtype) + ae = model_cls.from_config(args.model_config, dtype=dtype, use_recompute=args.use_recompute) if args.load_from_checkpoint is not None: ae.init_from_ckpt(args.load_from_checkpoint) @@ -334,6 +335,7 @@ def main(args): f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}" + (f"\nJit level: {args.jit_level}" if args.mode == 0 else ""), f"Distributed mode: {args.use_parallel}", + f"Recompute: {args.use_recompute}", f"dtype: {args.precision}", f"Optimizer: {args.optim}", f"Use discriminator: {args.use_discriminator}", @@ -441,7 +443,6 @@ def main(args): f"Overflow occurs in step {cur_global_step} in discriminator" + (", drop update." if args.drop_overflow_update else ", still update.") ) - # log step_time = time.time() - start_time_s if step % args.log_interval == 0: diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh index 259c2316d0..06804d4a42 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh @@ -7,6 +7,8 @@ exp_name="25x256x256" msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir=$output_dir/$exp_name/parallel_logs opensora/train/train_causalvae.py \ --exp_name $exp_name \ + --model_name WFVAE \ + --model_config scripts/causalvae/wfvae_4dim.json \ --train_batch_size 1 \ --precision fp32 \ --max_steps 100000 \ From f343d908cc47d04379ed89dd38453f3a0086118b Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 5 Nov 2024 16:28:33 +0800 Subject: [PATCH 027/133] correct loss scaler print --- examples/opensora_pku/opensora/train/train_causalvae.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/opensora_pku/opensora/train/train_causalvae.py b/examples/opensora_pku/opensora/train/train_causalvae.py index 4f2c1c1e7b..8f2a949109 100644 --- a/examples/opensora_pku/opensora/train/train_causalvae.py +++ b/examples/opensora_pku/opensora/train/train_causalvae.py @@ -256,6 +256,7 @@ def main(args): lr=lr, ) loss_scaler_ae = create_loss_scaler(args) + scaling_sens = loss_scaler_ae.loss_scale_value if use_discriminator: optim_disc = create_optimizer( From 7ab7661c349e5fc7fc744e377f5a03eeb110da8d Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 5 Nov 2024 17:48:56 +0800 Subject: [PATCH 028/133] fix print loss scale --- examples/opensora_pku/opensora/train/train_causalvae.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/opensora_pku/opensora/train/train_causalvae.py b/examples/opensora_pku/opensora/train/train_causalvae.py index 8f2a949109..ae64beaaa0 100644 --- a/examples/opensora_pku/opensora/train/train_causalvae.py +++ b/examples/opensora_pku/opensora/train/train_causalvae.py @@ -429,9 +429,13 @@ def main(args): # NOTE: inputs must match the order in GeneratorWithLoss.construct loss_ae_t, overflow, scaling_sens = training_step_ae(x, global_step) + if isinstance(scaling_sens, ms.Parameter): + scaling_sens = scaling_sens.value() if global_step >= disc_start: loss_disc_t, overflow_d, scaling_sens_d = training_step_disc(x, global_step) + if isinstance(scaling_sens_d, ms.Parameter): + scaling_sens_d = scaling_sens_d.value() cur_global_step = epoch * dataset_size + step + 1 # starting from 1 for logging if overflow: From 58607a9074a3d89cbb42f9ba3eae03da513656ca Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 6 Nov 2024 11:25:56 +0800 Subject: [PATCH 029/133] update eval --- examples/opensora_pku/opensora/eval/eval_common_metrics.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/opensora_pku/opensora/eval/eval_common_metrics.py b/examples/opensora_pku/opensora/eval/eval_common_metrics.py index 39cf479bb7..d1999432c0 100644 --- a/examples/opensora_pku/opensora/eval/eval_common_metrics.py +++ b/examples/opensora_pku/opensora/eval/eval_common_metrics.py @@ -27,7 +27,6 @@ See the License for the specific language governing permissions and limitations under the License. """ - import csv import math import os From c79c621e0e8d8e55e44fe962b842a3d41c46337f Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 6 Nov 2024 11:32:19 +0800 Subject: [PATCH 030/133] copy to vae/eval --- .../models/causalvideovae/eval/eval.py | 33 ++++++++++++++++--- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/eval/eval.py b/examples/opensora_pku/opensora/models/causalvideovae/eval/eval.py index 0993a4a543..d1999432c0 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/eval/eval.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/eval/eval.py @@ -27,7 +27,7 @@ See the License for the specific language governing permissions and limitations under the License. """ - +import csv import math import os import os.path as osp @@ -62,13 +62,20 @@ def __init__( real_video_dir, generated_video_dir, num_frames, + real_data_file_path=None, sample_rate=1, crop_size=None, resolution=128, output_columns=["real", "generated"], ) -> None: super().__init__() - self.real_video_files = self.combine_without_prefix(real_video_dir) + if real_data_file_path is not None: + print(f"Loading videos from data file {real_data_file_path}") + self.parse_data_file(real_data_file_path) + self.read_from_data_file = True + else: + self.real_video_files = self.combine_without_prefix(real_video_dir) + self.read_from_data_file = False self.generated_video_files = self.combine_without_prefix(generated_video_dir) assert ( len(self.real_video_files) == len(self.generated_video_files) and len(self.real_video_files) > 0 @@ -78,6 +85,7 @@ def __init__( self.crop_size = crop_size self.short_size = resolution self.output_columns = output_columns + self.real_video_dir = real_video_dir self.pixel_transforms = create_video_transforms( size=self.short_size, @@ -94,7 +102,12 @@ def __len__(self): def __getitem__(self, index): if index >= len(self): raise IndexError - real_video_file = self.real_video_files[index] + if self.read_from_data_file: + video_dict = self.real_video_files[index] + video_fn = video_dict["video"] + real_video_file = os.path.join(self.real_video_dir, video_fn) + else: + real_video_file = self.real_video_files[index] generated_video_file = self.generated_video_files[index] if os.path.basename(real_video_file).split(".")[0] != os.path.basename(generated_video_file).split(".")[0]: print( @@ -104,6 +117,14 @@ def __getitem__(self, index): generated_video_tensor = self._load_video(generated_video_file) return real_video_tensor.astype(np.float32), generated_video_tensor.astype(np.float32) + def parse_data_file(self, data_file_path): + if data_file_path.endswith(".csv"): + with open(data_file_path, "r") as csvfile: + self.real_video_files = list(csv.DictReader(csvfile)) + else: + raise ValueError("Only support csv file now!") + self.real_video_files = sorted(self.real_video_files, key=lambda x: os.path.basename(x["video"])) + def _load_video(self, video_path): num_frames = self.num_frames sample_rate = self.sample_rate @@ -146,11 +167,11 @@ def combine_without_prefix(self, folder_path, prefix="."): folder = [] assert os.path.exists(folder_path), f"Expect that {folder_path} exist!" for name in os.listdir(folder_path): - if name[0] == prefix: + if name[0] == prefix or name.split(".")[1] == "txt": continue if osp.isfile(osp.join(folder_path, name)): folder.append(osp.join(folder_path, name)) - folder.sort() + folder = sorted(folder, key=lambda x: os.path.basename(x)) return folder @@ -203,6 +224,7 @@ def main(): parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) parser.add_argument("--batch_size", type=int, default=2, help="Batch size to use") parser.add_argument("--real_video_dir", type=str, help=("the path of real videos`")) + parser.add_argument("--real_data_file_path", type=str, default=None, help=("the path of real videos csv file`")) parser.add_argument("--generated_video_dir", type=str, help=("the path of generated videos`")) parser.add_argument("--device", type=str, default=None, help="Device to use. Like GPU or Ascend") parser.add_argument( @@ -239,6 +261,7 @@ def main(): args.real_video_dir, args.generated_video_dir, num_frames=args.num_frames, + real_data_file_path=args.real_data_file_path, sample_rate=args.sample_rate, crop_size=args.crop_size, resolution=args.resolution, From 727ccd0de367763b70d200fe21d625db62a608d5 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 6 Nov 2024 14:54:52 +0800 Subject: [PATCH 031/133] a new vae config file --- .../scripts/causalvae/train_with_gan_loss.sh | 2 +- .../train_with_gan_loss_multi_device.sh | 2 +- .../scripts/causalvae/wfvae_8dim.json | 23 +++++++++++++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 examples/opensora_pku/scripts/causalvae/wfvae_8dim.json diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh index 1beca45127..9ced6f905d 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh @@ -1,7 +1,7 @@ python opensora/train/train_causalvae.py \ --exp_name "25x256x256" \ --model_name WFVAE \ - --model_config scripts/causalvae/wfvae_4dim.json \ + --model_config scripts/causalvae/wfvae_8dim.json \ --train_batch_size 1 \ --precision fp32 \ --max_steps 100000 \ diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh index 06804d4a42..7034543230 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh @@ -8,7 +8,7 @@ exp_name="25x256x256" msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir=$output_dir/$exp_name/parallel_logs opensora/train/train_causalvae.py \ --exp_name $exp_name \ --model_name WFVAE \ - --model_config scripts/causalvae/wfvae_4dim.json \ + --model_config scripts/causalvae/wfvae_8dim.json \ --train_batch_size 1 \ --precision fp32 \ --max_steps 100000 \ diff --git a/examples/opensora_pku/scripts/causalvae/wfvae_8dim.json b/examples/opensora_pku/scripts/causalvae/wfvae_8dim.json new file mode 100644 index 0000000000..995d10e868 --- /dev/null +++ b/examples/opensora_pku/scripts/causalvae/wfvae_8dim.json @@ -0,0 +1,23 @@ +{ + "_class_name": "WFVAEModel", + "_diffusers_version": "0.30.2", + "base_channels": 128, + "connect_res_layer_num": 1, + "decoder_energy_flow_hidden_size": 128, + "decoder_num_resblocks": 2, + "dropout": 0.0, + "encoder_energy_flow_hidden_size": 64, + "encoder_num_resblocks": 2, + "l1_dowmsample_block": "Downsample", + "l1_downsample_wavelet": "HaarWaveletTransform2D", + "l1_upsample_block": "Upsample", + "l1_upsample_wavelet": "InverseHaarWaveletTransform2D", + "l2_dowmsample_block": "Spatial2xTime2x3DDownsample", + "l2_downsample_wavelet": "HaarWaveletTransform3D", + "l2_upsample_block": "Spatial2xTime2x3DUpsample", + "l2_upsample_wavelet": "InverseHaarWaveletTransform3D", + "latent_dim": 8, + "norm_type": "layernorm", + "t_interpolation": "trilinear", + "use_attention": true +} From 2ab6697c9e863b5e3f8ad0a2330f8d3121026a5d Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 6 Nov 2024 14:56:10 +0800 Subject: [PATCH 032/133] remove prefix --- examples/opensora_pku/examples/rec_image.py | 2 +- examples/opensora_pku/examples/rec_video.py | 2 +- .../opensora/models/causalvideovae/__init__.py | 16 ++++++++++++---- .../opensora_pku/opensora/sample/rec_image.py | 2 +- .../opensora_pku/opensora/sample/rec_video.py | 2 +- 5 files changed, 16 insertions(+), 8 deletions(-) diff --git a/examples/opensora_pku/examples/rec_image.py b/examples/opensora_pku/examples/rec_image.py index e667b506d3..e8476fc315 100644 --- a/examples/opensora_pku/examples/rec_image.py +++ b/examples/opensora_pku/examples/rec_image.py @@ -69,7 +69,7 @@ def main(args): if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint): logger.info(f"Run inference with MindSpore checkpoint {args.ms_checkpoint}") state_dict = ms.load_checkpoint(args.ms_checkpoint) - # rm 'network.' prefix + state_dict = dict( [k.replace("autoencoder.", "") if k.startswith("autoencoder.") else k, v] for k, v in state_dict.items() ) diff --git a/examples/opensora_pku/examples/rec_video.py b/examples/opensora_pku/examples/rec_video.py index 2e342c1210..355c0c1f43 100644 --- a/examples/opensora_pku/examples/rec_video.py +++ b/examples/opensora_pku/examples/rec_video.py @@ -107,7 +107,7 @@ def main(args): if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint): logger.info(f"Run inference with MindSpore checkpoint {args.ms_checkpoint}") state_dict = ms.load_checkpoint(args.ms_checkpoint) - # rm 'network.' prefix + state_dict = dict( [k.replace("autoencoder.", "") if k.startswith("autoencoder.") else k, v] for k, v in state_dict.items() ) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/__init__.py b/examples/opensora_pku/opensora/models/causalvideovae/__init__.py index 4c968af7e9..bfe1f311c7 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/__init__.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/__init__.py @@ -1,4 +1,5 @@ import logging +import os import mindspore as ms from mindspore import nn @@ -38,11 +39,18 @@ def dtype(self): class WFVAEModelWrapper(nn.Cell): - def __init__(self, model_path, dtype=ms.float32, subfolder=None, cache_dir=None, **kwargs): + def __init__(self, model_path=None, dtype=ms.float32, subfolder=None, cache_dir=None, vae=None, **kwargs): super(WFVAEModelWrapper, self).__init__() - self.vae = WFVAEModel.from_pretrained( - model_path, subfolder=subfolder, cache_dir=cache_dir, dtype=dtype, **kwargs - ) + assert model_path is not None or vae is not None, "At least oen of [`model_path`, `vae`] should be provided." + + if vae is not None: + self.vae = vae + else: + assert model_path is not None, "When `vae` is not None, expect to get `model_path`!" + assert os.path.exists(model_path), f"`model_path` does not exist!: {model_path}" + self.vae = WFVAEModel.from_pretrained( + model_path, subfolder=subfolder, cache_dir=cache_dir, dtype=dtype, **kwargs + ) self.shift = ms.Tensor(self.vae.config.shift)[None, :, None, None, None] self.scale = ms.Tensor(self.vae.config.scale)[None, :, None, None, None] diff --git a/examples/opensora_pku/opensora/sample/rec_image.py b/examples/opensora_pku/opensora/sample/rec_image.py index 36191269e8..6d5984d962 100644 --- a/examples/opensora_pku/opensora/sample/rec_image.py +++ b/examples/opensora_pku/opensora/sample/rec_image.py @@ -69,7 +69,7 @@ def main(args): if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint): logger.info(f"Run inference with MindSpore checkpoint {args.ms_checkpoint}") state_dict = ms.load_checkpoint(args.ms_checkpoint) - # rm 'network.' prefix + state_dict = dict( [k.replace("network.", "") if k.startswith("network.") else k, v] for k, v in state_dict.items() ) diff --git a/examples/opensora_pku/opensora/sample/rec_video.py b/examples/opensora_pku/opensora/sample/rec_video.py index f5463b4d9c..a6fbc6c244 100644 --- a/examples/opensora_pku/opensora/sample/rec_video.py +++ b/examples/opensora_pku/opensora/sample/rec_video.py @@ -119,7 +119,7 @@ def main(args): if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint): logger.info(f"Run inference with MindSpore checkpoint {args.ms_checkpoint}") state_dict = ms.load_checkpoint(args.ms_checkpoint) - # rm 'network.' prefix + state_dict = dict( [k.replace("network.", "") if k.startswith("network.") else k, v] for k, v in state_dict.items() ) From 599e57e147ef34630c454a2fbf58530adf5ab26f Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 6 Nov 2024 14:59:13 +0800 Subject: [PATCH 033/133] allow model config --- examples/opensora_pku/examples/rec_image.py | 22 +++++++++++++++++ examples/opensora_pku/examples/rec_video.py | 22 +++++++++++++++++ .../opensora_pku/examples/rec_video_folder.py | 24 ++++++++++++++++++- 3 files changed, 67 insertions(+), 1 deletion(-) diff --git a/examples/opensora_pku/examples/rec_image.py b/examples/opensora_pku/examples/rec_image.py index e8476fc315..5b2cb0110b 100644 --- a/examples/opensora_pku/examples/rec_image.py +++ b/examples/opensora_pku/examples/rec_image.py @@ -1,6 +1,7 @@ import argparse import logging import os +import re import sys import cv2 @@ -19,6 +20,7 @@ sys.path.append(".") from opensora.models.causalvideovae import ae_wrapper +from opensora.models.causalvideovae.model.registry import ModelRegistry from opensora.npu_config import npu_config from opensora.utils.utils import get_precision @@ -75,10 +77,27 @@ def main(args): ) else: state_dict = None + + if args.model_config is not None: + assert os.path.exists(args.model_config), f"`model_config` does not exist! {args.model_config}" + pattern = r"^([A-Za-z]+)Model" + if re.match(pattern, args.ae): + model_name = re.match(pattern, args.ae).group(1) + model_cls = ModelRegistry.get_model(model_name) + vae = model_cls.from_config(args.model_config, dtype=dtype) + if args.ms_checkpoint is None or not os.path.exists(args.ms_checkpoint): + logger.warning( + "VAE is randomly initialized. The inference results may be incorrect! Check `ms_checkpoint`!" + ) + + else: + logger.warning(f"Incorrect ae name, must be one of {ae_wrapper.keys()}") + vae = None kwarg = { "state_dict": state_dict, "use_safetensors": True, "dtype": dtype, + "vae": vae, } vae = ae_wrapper[args.ae](args.ae_path, **kwarg) @@ -160,5 +179,8 @@ def main(args): parser.add_argument( "--jit_syntax_level", default="strict", choices=["strict", "lax"], help="Set jit syntax level: strict or lax" ) + parser.add_argument( + "--model_config", type=str, default=None, help="The model config file for initiating vae model." + ) args = parser.parse_args() main(args) diff --git a/examples/opensora_pku/examples/rec_video.py b/examples/opensora_pku/examples/rec_video.py index 355c0c1f43..d7a3ce4d4b 100644 --- a/examples/opensora_pku/examples/rec_video.py +++ b/examples/opensora_pku/examples/rec_video.py @@ -2,6 +2,7 @@ import logging import os import random +import re import sys import numpy as np @@ -21,6 +22,7 @@ from albumentations import Compose, Lambda, Resize, ToFloat from opensora.dataset.transform import center_crop_th_tw from opensora.models.causalvideovae import ae_wrapper +from opensora.models.causalvideovae.model.registry import ModelRegistry from opensora.npu_config import npu_config from opensora.utils.utils import get_precision from opensora.utils.video_utils import save_videos @@ -113,10 +115,27 @@ def main(args): ) else: state_dict = None + + if args.model_config is not None: + assert os.path.exists(args.model_config), f"`model_config` does not exist! {args.model_config}" + pattern = r"^([A-Za-z]+)Model" + if re.match(pattern, args.ae): + model_name = re.match(pattern, args.ae).group(1) + model_cls = ModelRegistry.get_model(model_name) + vae = model_cls.from_config(args.model_config, dtype=dtype) + if args.ms_checkpoint is None or not os.path.exists(args.ms_checkpoint): + logger.warning( + "VAE is randomly initialized. The inference results may be incorrect! Check `ms_checkpoint`!" + ) + + else: + logger.warning(f"Incorrect ae name, must be one of {ae_wrapper.keys()}") + vae = None kwarg = { "state_dict": state_dict, "use_safetensors": True, "dtype": dtype, + "vae": vae, } vae = ae_wrapper[args.ae](args.ae_path, **kwarg) @@ -205,5 +224,8 @@ def main(args): parser.add_argument( "--jit_syntax_level", default="strict", choices=["strict", "lax"], help="Set jit syntax level: strict or lax" ) + parser.add_argument( + "--model_config", type=str, default=None, help="The model config file for initiating vae model." + ) args = parser.parse_args() main(args) diff --git a/examples/opensora_pku/examples/rec_video_folder.py b/examples/opensora_pku/examples/rec_video_folder.py index 28e9da1942..a90da80c84 100644 --- a/examples/opensora_pku/examples/rec_video_folder.py +++ b/examples/opensora_pku/examples/rec_video_folder.py @@ -1,6 +1,7 @@ import argparse import logging import os +import re import sys import numpy as np @@ -18,6 +19,7 @@ from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info from opensora.models.causalvideovae import ae_wrapper from opensora.models.causalvideovae.model.dataset_videobase import VideoDataset, create_dataloader +from opensora.models.causalvideovae.model.registry import ModelRegistry from opensora.npu_config import npu_config from opensora.utils.utils import get_precision from opensora.utils.video_utils import save_videos @@ -53,16 +55,33 @@ def main(args): if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint): logger.info(f"Run inference with MindSpore checkpoint {args.ms_checkpoint}") state_dict = ms.load_checkpoint(args.ms_checkpoint) - # rm 'network.' prefix + state_dict = dict( [k.replace("autoencoder.", "") if k.startswith("autoencoder.") else k, v] for k, v in state_dict.items() ) else: state_dict = None + + if args.model_config is not None: + assert os.path.exists(args.model_config), f"`model_config` does not exist! {args.model_config}" + pattern = r"^([A-Za-z]+)Model" + if re.match(pattern, args.ae): + model_name = re.match(pattern, args.ae).group(1) + model_cls = ModelRegistry.get_model(model_name) + vae = model_cls.from_config(args.model_config, dtype=dtype) + if args.ms_checkpoint is None or not os.path.exists(args.ms_checkpoint): + logger.warning( + "VAE is randomly initialized. The inference results may be incorrect! Check `ms_checkpoint`!" + ) + + else: + logger.warning(f"Incorrect ae name, must be one of {ae_wrapper.keys()}") + vae = None kwarg = { "state_dict": state_dict, "use_safetensors": True, "dtype": dtype, + "vae": vae, } vae = ae_wrapper[args.ae](args.ae_path, **kwarg) @@ -235,5 +254,8 @@ def main(args): default="video", help="The column of video file path in `data_file_path`. Defaults to `video`.", ) + parser.add_argument( + "--model_config", type=str, default=None, help="The model config file for initiating vae model." + ) args = parser.parse_args() main(args) From fc080ba8b97a2c337c0c848965f44af1bbe776b9 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 6 Nov 2024 15:13:07 +0800 Subject: [PATCH 034/133] update 4dim json --- .../opensora_pku/scripts/causalvae/wfvae_4dim.json | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/examples/opensora_pku/scripts/causalvae/wfvae_4dim.json b/examples/opensora_pku/scripts/causalvae/wfvae_4dim.json index 6509a76d4f..4209487208 100644 --- a/examples/opensora_pku/scripts/causalvae/wfvae_4dim.json +++ b/examples/opensora_pku/scripts/causalvae/wfvae_4dim.json @@ -19,5 +19,17 @@ "latent_dim": 4, "norm_type": "layernorm", "t_interpolation": "trilinear", - "use_attention": true + "use_attention": true, + "scale": [ + 0.18215, + 0.18215, + 0.18215, + 0.18215 + ], + "shift": [ + 0, + 0, + 0, + 0 + ] } From b13b35a58fb2487919ba524c02b8d82dcc844b2c Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 6 Nov 2024 15:18:57 +0800 Subject: [PATCH 035/133] save diffusers config --- .../opensora/train/train_causalvae.py | 5 +++-- examples/opensora_pku/opensora/utils/utils.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/examples/opensora_pku/opensora/train/train_causalvae.py b/examples/opensora_pku/opensora/train/train_causalvae.py index ae64beaaa0..9cc2e6ff9f 100644 --- a/examples/opensora_pku/opensora/train/train_causalvae.py +++ b/examples/opensora_pku/opensora/train/train_causalvae.py @@ -24,7 +24,7 @@ from opensora.models.causalvideovae.model.utils.model_utils import resolve_str_to_obj from opensora.npu_config import npu_config from opensora.train.commons import create_loss_scaler, parse_args -from opensora.utils.utils import get_precision +from opensora.utils.utils import get_precision, save_diffusers_json from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, ProfilerCallback from mindone.trainers.checkpoint import CheckpointManager, resume_train_network @@ -67,7 +67,8 @@ def main(args): if rank_id == 0: logger.warning(f"Model will be initialized from config file {args.model_config}.") ae = model_cls.from_config(args.model_config, dtype=dtype, use_recompute=args.use_recompute) - + json_name = os.path.join(args.output_dir, "config.json") + save_diffusers_json(ae.config, json_name) if args.load_from_checkpoint is not None: ae.init_from_ckpt(args.load_from_checkpoint) # discriminator (D) diff --git a/examples/opensora_pku/opensora/utils/utils.py b/examples/opensora_pku/opensora/utils/utils.py index 7ba97b7a4b..a6890c975b 100644 --- a/examples/opensora_pku/opensora/utils/utils.py +++ b/examples/opensora_pku/opensora/utils/utils.py @@ -1,6 +1,7 @@ import argparse import collections import html +import json import logging import re import urllib.parse as ul @@ -24,6 +25,23 @@ logger = logging.getLogger(__name__) +# Custom JSON Encoder to serialize everything as strings +class StringifyJSONEncoder(json.JSONEncoder): + def default(self, obj): + # Convert the object to a string + return str(obj) + + +def save_diffusers_json(config, filename): + if not isinstance(config, dict): + config = dict(config) + + # Save the regular dictionary to a JSON file using the custom encoder + with open(filename, "w") as json_file: + json.dump(config, json_file, cls=StringifyJSONEncoder, indent=4) + logger.info(f"Save config file to {filename}") + + def to_2tuple(x): if isinstance(x, collections.abc.Iterable): return x From 3835bfcbaebb88feadcf2b8bff62a06a73bc29c4 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 7 Nov 2024 10:32:01 +0800 Subject: [PATCH 036/133] fix error --- examples/opensora_pku/examples/rec_image.py | 4 ++-- examples/opensora_pku/examples/rec_video.py | 3 ++- examples/opensora_pku/examples/rec_video_folder.py | 10 +++++----- examples/opensora_pku/opensora/npu_config.py | 1 + 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/opensora_pku/examples/rec_image.py b/examples/opensora_pku/examples/rec_image.py index 5b2cb0110b..d7f64a2d8d 100644 --- a/examples/opensora_pku/examples/rec_image.py +++ b/examples/opensora_pku/examples/rec_image.py @@ -77,7 +77,7 @@ def main(args): ) else: state_dict = None - + vae = None if args.model_config is not None: assert os.path.exists(args.model_config), f"`model_config` does not exist! {args.model_config}" pattern = r"^([A-Za-z]+)Model" @@ -92,7 +92,7 @@ def main(args): else: logger.warning(f"Incorrect ae name, must be one of {ae_wrapper.keys()}") - vae = None + kwarg = { "state_dict": state_dict, "use_safetensors": True, diff --git a/examples/opensora_pku/examples/rec_video.py b/examples/opensora_pku/examples/rec_video.py index d7a3ce4d4b..d1eea5234d 100644 --- a/examples/opensora_pku/examples/rec_video.py +++ b/examples/opensora_pku/examples/rec_video.py @@ -116,6 +116,7 @@ def main(args): else: state_dict = None + vae = None if args.model_config is not None: assert os.path.exists(args.model_config), f"`model_config` does not exist! {args.model_config}" pattern = r"^([A-Za-z]+)Model" @@ -130,7 +131,7 @@ def main(args): else: logger.warning(f"Incorrect ae name, must be one of {ae_wrapper.keys()}") - vae = None + kwarg = { "state_dict": state_dict, "use_safetensors": True, diff --git a/examples/opensora_pku/examples/rec_video_folder.py b/examples/opensora_pku/examples/rec_video_folder.py index a90da80c84..727116a6bd 100644 --- a/examples/opensora_pku/examples/rec_video_folder.py +++ b/examples/opensora_pku/examples/rec_video_folder.py @@ -62,6 +62,7 @@ def main(args): else: state_dict = None + vae = None if args.model_config is not None: assert os.path.exists(args.model_config), f"`model_config` does not exist! {args.model_config}" pattern = r"^([A-Za-z]+)Model" @@ -76,7 +77,7 @@ def main(args): else: logger.warning(f"Incorrect ae name, must be one of {ae_wrapper.keys()}") - vae = None + kwarg = { "state_dict": state_dict, "use_safetensors": True, @@ -130,7 +131,7 @@ def main(args): ) num_batches = dataloader.get_dataset_size() logger.info("Number of batches: %d", num_batches) - ds_iter = dataloader.create_dict_iterator(1) + ds_iter = dataloader.create_dict_iterator(1, output_numpy=True) # ---- Prepare Dataset # ---- Inference ---- @@ -140,12 +141,11 @@ def main(args): else: x = batch["video"] file_paths = batch["path"] - x = x.to(dtype=dtype) # b c t h w + x = ms.Tensor(x, dtype=dtype) # b c t h w latents = vae.encode(x) video_recon = vae.decode(latents) for idx, video in enumerate(video_recon): - file_paths = eval(str(file_paths).replace("/n", ",")) - file_name = os.path.basename(file_paths[idx]) + file_name = os.path.basename(str(file_paths[idx])) if ".avi" in os.path.basename(file_name): file_name = file_name.replace(".avi", ".mp4") output_path = os.path.join(generated_video_dir, file_name) diff --git a/examples/opensora_pku/opensora/npu_config.py b/examples/opensora_pku/opensora/npu_config.py index 2afde53e51..4b5e72126a 100644 --- a/examples/opensora_pku/opensora/npu_config.py +++ b/examples/opensora_pku/opensora/npu_config.py @@ -67,6 +67,7 @@ def set_npu_env(self, args): rank_id, device_num = init_env( mode=args.mode, device_target=args.device, + distributed=args.use_parallel, precision_mode=getattr(args, "precision_mode", None), jit_level=getattr(args, "jit_level", None), jit_syntax_level=getattr(args, "jit_syntax_level", "strict"), From 67c8f28f3997434c737260c70f6018856b7cf178 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 7 Nov 2024 10:41:37 +0800 Subject: [PATCH 037/133] conv dtype bf16 --- examples/opensora_pku/opensora/npu_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/opensora_pku/opensora/npu_config.py b/examples/opensora_pku/opensora/npu_config.py index 4b5e72126a..0cb4d9cf8e 100644 --- a/examples/opensora_pku/opensora/npu_config.py +++ b/examples/opensora_pku/opensora/npu_config.py @@ -44,7 +44,7 @@ def __init__(self): self.original_run_dtype = None self.replaced_type = ms.float32 - self.conv_dtype = ms.float16 + self.conv_dtype = ms.bfloat16 # FIXME: torch uses float16 if self.enable_FA and self.enable_FP32: self.inf_float = -10000.0 else: From 5642d1afc2ffdab60c103cf137dc66bc05af8315 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 7 Nov 2024 10:50:05 +0800 Subject: [PATCH 038/133] conv 3d tranpose to fp16 --- .../models/causalvideovae/model/modules/wavelet.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py index f7c52258e4..2f96f47c22 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py @@ -1,3 +1,5 @@ +import logging + import mindspore as ms from mindspore import Tensor, mint, nn, ops @@ -9,6 +11,8 @@ except ImportError: npu_config = None +logger = logging.getLogger(__name__) + class HaarWaveletTransform3D(nn.Cell): def __init__(self, dtype=ms.float32) -> None: @@ -119,10 +123,11 @@ def __init__(self, enable_cached=False, dtype=ms.float16, *args, **kwargs) -> No self.dtype = dtype - if self.dtype == ms.float32 or self.dtype == ms.bfloat16: + if self.dtype == ms.bfloat16: + # Conv3dTranspose does not support bf16 self.dtype = ms.float16 dtype = ms.float16 - print("conv3d transpose layer is forced to fp16") + logger.warning("conv3d transpose layer does not support ms.bfloat16, and is forced to use ms.float16!") self.h = Tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]], dtype=dtype).view(1, 1, 2, 2, 2) * 0.3536 self.g = Tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]], dtype=dtype).view(1, 1, 2, 2, 2) * 0.3536 From 6f543d17e28729255ccd6d38abd8506e52c94e9f Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 7 Nov 2024 10:51:25 +0800 Subject: [PATCH 039/133] init loss scale 65536 --- examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh | 2 +- .../scripts/causalvae/train_with_gan_loss_multi_device.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh index 9ced6f905d..d35b6abdaf 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh @@ -19,7 +19,7 @@ python opensora/train/train_causalvae.py \ --clip_grad True \ --weight_decay 0.0 \ --mode 1 \ - --init_loss_scale 1 \ + --init_loss_scale 65536 \ --jit_level "O0" \ --use_discriminator True \ --use_ema True\ diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh index 7034543230..d5817e4e19 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh @@ -26,7 +26,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --clip_grad True \ --weight_decay 0.0 \ --mode 1 \ - --init_loss_scale 1 \ + --init_loss_scale 65536 \ --jit_level "O0" \ --use_discriminator True \ --use_parallel True \ From 16efbb8832c1b3a549a335747919c4175cafd827 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 7 Nov 2024 10:53:27 +0800 Subject: [PATCH 040/133] updates --- .../opensora/models/causalvideovae/model/modules/wavelet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py index 2f96f47c22..539cd70c2d 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py @@ -118,7 +118,7 @@ def construct(self, x): class InverseHaarWaveletTransform3D(nn.Cell): - def __init__(self, enable_cached=False, dtype=ms.float16, *args, **kwargs) -> None: + def __init__(self, enable_cached=False, dtype=ms.float32, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.dtype = dtype From 5b8bf94fd5571aaa18f4f15cb93a74bfbd8f1e65 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 7 Nov 2024 11:36:43 +0800 Subject: [PATCH 041/133] conv2d transpose use fp16 --- .../opensora/models/causalvideovae/model/modules/wavelet.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py index 539cd70c2d..74368d7435 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py @@ -123,11 +123,10 @@ def __init__(self, enable_cached=False, dtype=ms.float32, *args, **kwargs) -> No self.dtype = dtype - if self.dtype == ms.bfloat16: - # Conv3dTranspose does not support bf16 + if self.dtype != ms.float16: + # Conv3dTranspose is forced to fp16 self.dtype = ms.float16 dtype = ms.float16 - logger.warning("conv3d transpose layer does not support ms.bfloat16, and is forced to use ms.float16!") self.h = Tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]], dtype=dtype).view(1, 1, 2, 2, 2) * 0.3536 self.g = Tensor([[[1, -1], [1, -1]], [[1, -1], [1, -1]]], dtype=dtype).view(1, 1, 2, 2, 2) * 0.3536 From 65bbda20288ab96cfa24d98370b7de15e38f13c9 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 11 Nov 2024 16:29:26 +0800 Subject: [PATCH 042/133] change conv2d initializer --- .../model/modules/resnet_block.py | 42 +++++++++++++++++-- .../model/modules/updownsample.py | 32 ++++++++++++-- .../model/vae/modeling_wfvae.py | 16 +++++++ 3 files changed, 83 insertions(+), 7 deletions(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/resnet_block.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/resnet_block.py index eb6a4c697d..7e5ba09ff3 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/resnet_block.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/resnet_block.py @@ -1,5 +1,8 @@ +import math + import mindspore as ms from mindspore import nn +from mindspore.common.initializer import HeUniform, Uniform try: from opensora.npu_config import npu_config @@ -81,21 +84,52 @@ def __init__( self.norm1 = Normalize(in_channels, norm_type=norm_type) self.conv1 = nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True + in_channels, + out_channels, + kernel_size=3, + stride=1, + pad_mode="pad", + padding=1, + has_bias=True, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(out_channels)), ).to_float(dtype) self.norm2 = Normalize(out_channels, norm_type=norm_type) self.dropout = nn.Dropout(p=dropout) self.conv2 = nn.Conv2d( - out_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True + out_channels, + out_channels, + kernel_size=3, + stride=1, + pad_mode="pad", + padding=1, + has_bias=True, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(out_channels)), ).to_float(dtype) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True + in_channels, + out_channels, + kernel_size=3, + stride=1, + pad_mode="pad", + padding=1, + has_bias=True, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(out_channels)), ).to_float(dtype) else: self.nin_shortcut = nn.Conv2d( - in_channels, out_channels, kernel_size=1, stride=1, pad_mode="valid", has_bias=True + in_channels, + out_channels, + kernel_size=1, + stride=1, + pad_mode="valid", + has_bias=True, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(out_channels)), ).to_float(dtype) @video_to_image diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py index 0f7f12043d..f02adcd549 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py @@ -1,9 +1,11 @@ +import math from typing import Tuple, Union from opensora.npu_config import npu_config import mindspore as ms from mindspore import mint, nn, ops +from mindspore.common.initializer import HeUniform, Uniform from .conv import CausalConv3d from .ops import cast_tuple, video_to_image @@ -16,7 +18,15 @@ def __init__(self, in_channels, out_channels, with_conv=True, dtype=ms.float32): self.with_conv = with_conv if self.with_conv: self.conv = nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True + in_channels, + out_channels, + kernel_size=3, + stride=1, + pad_mode="pad", + padding=1, + has_bias=True, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(out_channels)), ).to_float(self.dtype) @video_to_image @@ -39,11 +49,27 @@ def __init__(self, in_channels, out_channels, undown=False, dtype=ms.float32): # no asymmetric padding in torch conv, must do it ourselves if self.undown: self.conv = nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1, pad_mode="pad", has_bias=True + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + pad_mode="pad", + has_bias=True, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(out_channels)), ).to_float(self.dtype) else: self.conv = nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=2, padding=0, pad_mode="pad", has_bias=True + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=0, + pad_mode="pad", + has_bias=True, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(out_channels)), ).to_float(self.dtype) @video_to_image diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py index e323c10238..ee370a5c02 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py @@ -1,4 +1,5 @@ import logging +import math import os from typing import List @@ -6,6 +7,7 @@ import mindspore as ms from mindspore import mint, nn +from mindspore.common.initializer import HeUniform, Uniform from mindone.diffusers import __version__ from mindone.diffusers.configuration_utils import register_to_config @@ -58,6 +60,8 @@ def __init__( padding=1, pad_mode="pad", has_bias=True, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(base_channels)), ).to_float(dtype), *[ ResnetBlock2D( @@ -80,6 +84,8 @@ def __init__( padding=1, pad_mode="pad", has_bias=True, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(base_channels * 2)), ).to_float(dtype), *[ ResnetBlock3D( @@ -107,6 +113,8 @@ def __init__( padding=1, pad_mode="pad", has_bias=True, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(energy_flow_hidden_size)), ).to_float(dtype) self.connect_l2 = Conv2d( 24, @@ -116,6 +124,8 @@ def __init__( padding=1, pad_mode="pad", has_bias=True, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(energy_flow_hidden_size)), ).to_float(dtype) # Mid mid_layers = [ @@ -289,6 +299,8 @@ def __init__( padding=1, pad_mode="pad", has_bias=True, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(l1_channels)), ).to_float(dtype), ) self.connect_l2 = nn.SequentialCell( @@ -310,6 +322,8 @@ def __init__( padding=1, pad_mode="pad", has_bias=True, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(24)), ).to_float(dtype), ) # Out @@ -322,6 +336,8 @@ def __init__( padding=1, pad_mode="pad", has_bias=True, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(24)), ).to_float(dtype) self.inverse_wavelet_tranform_l1 = resolve_str_to_obj(l1_upsample_wavelet)(dtype=dtype) From c54809ae0a7280f0d9e5ed72c95dbf8a7613fe1e Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 11 Nov 2024 16:34:27 +0800 Subject: [PATCH 043/133] conv3d initializer --- .../model/losses/discriminator.py | 31 +++++++++++++++++-- .../causalvideovae/model/modules/conv.py | 15 ++++++++- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py index 03c28c8967..1d6185e090 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py @@ -1,10 +1,11 @@ import functools -from typing import Tuple, Union +import math from opensora.npu_config import npu_config import mindspore as ms from mindspore import nn +from mindspore.common.initializer import HeUniform, Uniform def weights_init(m): @@ -43,7 +44,17 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.f kw = 3 padw = 1 sequence = [ - nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, pad_mode="pad", padding=padw, has_bias=True), + nn.Conv3d( + input_nc, + ndf, + kernel_size=kw, + stride=2, + pad_mode="pad", + padding=padw, + has_bias=True, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(ndf)), + ), nn.LeakyReLU(0.2).to_float(self.dtype), ] nf_mult = 1 @@ -60,6 +71,8 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.f padding=padw, pad_mode="pad", has_bias=use_bias, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(ndf * nf_mult)), ), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2).to_float(self.dtype), @@ -76,13 +89,25 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.f padding=padw, pad_mode="pad", has_bias=use_bias, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(ndf * nf_mult)), ), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2).to_float(self.dtype), ] sequence += [ - Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw, pad_mode="pad", has_bias=True) + nn.Conv3d( + ndf * nf_mult, + 1, + kernel_size=kw, + stride=1, + padding=padw, + pad_mode="pad", + has_bias=True, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(1)), + ) ] # output 1 channel prediction map self.main = nn.CellList(sequence) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py index bbf1c4b360..e694bef71e 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py @@ -1,5 +1,8 @@ +import math from typing import Tuple, Union +from mindspore.common.initializer import HeUniform, Uniform + try: from opensora.npu_config import npu_config except ImportError: @@ -76,7 +79,15 @@ def __init__( self.stride = cast_tuple(self.stride, 3) if self.padding == 0: self.conv = nn.Conv3d( - chan_in, chan_out, self.kernel_size, stride=self.stride, pad_mode="valid", has_bias=bias, **kwargs + chan_in, + chan_out, + self.kernel_size, + stride=self.stride, + pad_mode="valid", + has_bias=bias, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(chan_out)), + **kwargs, ) else: self.padding = list(cast_tuple(self.padding, 6)) @@ -91,6 +102,8 @@ def __init__( padding=tuple(self.padding), pad_mode="pad", has_bias=bias, + weight_init=HeUniform(negative_slope=math.sqrt(5)), + bias_init=Uniform(scale=1 / math.sqrt(chan_out)), **kwargs, ) self.enable_cached = enable_cached From ce3d314a14c0d5c373e58921703526dd16f9a5c3 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 11 Nov 2024 16:48:34 +0800 Subject: [PATCH 044/133] correct to fan_int --- .../causalvideovae/model/losses/discriminator.py | 8 ++++---- .../models/causalvideovae/model/modules/conv.py | 8 ++++++-- .../causalvideovae/model/modules/resnet_block.py | 8 ++++---- .../causalvideovae/model/modules/updownsample.py | 6 +++--- .../causalvideovae/model/vae/modeling_wfvae.py | 14 +++++++------- 5 files changed, 24 insertions(+), 20 deletions(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py index 1d6185e090..80f6df9b9f 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py @@ -53,7 +53,7 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.f padding=padw, has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(ndf)), + bias_init=Uniform(scale=1 / math.sqrt(input_nc * kw * kw * kw)), ), nn.LeakyReLU(0.2).to_float(self.dtype), ] @@ -72,7 +72,7 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.f pad_mode="pad", has_bias=use_bias, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(ndf * nf_mult)), + bias_init=Uniform(scale=1 / math.sqrt(ndf * nf_mult_prev * kw * kw * kw)), ), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2).to_float(self.dtype), @@ -90,7 +90,7 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.f pad_mode="pad", has_bias=use_bias, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(ndf * nf_mult)), + bias_init=Uniform(scale=1 / math.sqrt(ndf * nf_mult_prev * kw * kw * kw)), ), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2).to_float(self.dtype), @@ -106,7 +106,7 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.f pad_mode="pad", has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(1)), + bias_init=Uniform(scale=1 / math.sqrt(ndf * nf_mult * kw * kw * kw)), ) ] # output 1 channel prediction map self.main = nn.CellList(sequence) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py index e694bef71e..4eb955a37c 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py @@ -86,7 +86,9 @@ def __init__( pad_mode="valid", has_bias=bias, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(chan_out)), + bias_init=Uniform( + scale=1 / math.sqrt(chan_in * self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]) + ), **kwargs, ) else: @@ -103,7 +105,9 @@ def __init__( pad_mode="pad", has_bias=bias, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(chan_out)), + bias_init=Uniform( + scale=1 / math.sqrt(chan_in * self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]) + ), **kwargs, ) self.enable_cached = enable_cached diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/resnet_block.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/resnet_block.py index 7e5ba09ff3..40dd40cebf 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/resnet_block.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/resnet_block.py @@ -92,7 +92,7 @@ def __init__( padding=1, has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(out_channels)), + bias_init=Uniform(scale=1 / math.sqrt(in_channels * 3 * 3)), ).to_float(dtype) self.norm2 = Normalize(out_channels, norm_type=norm_type) self.dropout = nn.Dropout(p=dropout) @@ -105,7 +105,7 @@ def __init__( padding=1, has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(out_channels)), + bias_init=Uniform(scale=1 / math.sqrt(out_channels * 3 * 3)), ).to_float(dtype) if self.in_channels != self.out_channels: if self.use_conv_shortcut: @@ -118,7 +118,7 @@ def __init__( padding=1, has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(out_channels)), + bias_init=Uniform(scale=1 / math.sqrt(in_channels * 3 * 3)), ).to_float(dtype) else: self.nin_shortcut = nn.Conv2d( @@ -129,7 +129,7 @@ def __init__( pad_mode="valid", has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(out_channels)), + bias_init=Uniform(scale=1 / math.sqrt(in_channels * 3 * 3)), ).to_float(dtype) @video_to_image diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py index f02adcd549..3e7c7626e4 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py @@ -26,7 +26,7 @@ def __init__(self, in_channels, out_channels, with_conv=True, dtype=ms.float32): padding=1, has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(out_channels)), + bias_init=Uniform(scale=1 / math.sqrt(in_channels * 3 * 3)), ).to_float(self.dtype) @video_to_image @@ -57,7 +57,7 @@ def __init__(self, in_channels, out_channels, undown=False, dtype=ms.float32): pad_mode="pad", has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(out_channels)), + bias_init=Uniform(scale=1 / math.sqrt(in_channels * 3 * 3)), ).to_float(self.dtype) else: self.conv = nn.Conv2d( @@ -69,7 +69,7 @@ def __init__(self, in_channels, out_channels, undown=False, dtype=ms.float32): pad_mode="pad", has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(out_channels)), + bias_init=Uniform(scale=1 / math.sqrt(in_channels * 3 * 3)), ).to_float(self.dtype) @video_to_image diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py index ee370a5c02..b1f0dc2fb8 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py @@ -61,7 +61,7 @@ def __init__( pad_mode="pad", has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(base_channels)), + bias_init=Uniform(scale=1 / math.sqrt(24 * 3 * 3)), ).to_float(dtype), *[ ResnetBlock2D( @@ -85,7 +85,7 @@ def __init__( pad_mode="pad", has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(base_channels * 2)), + bias_init=Uniform(scale=1 / math.sqrt((base_channels + energy_flow_hidden_size) * 3 * 3)), ).to_float(dtype), *[ ResnetBlock3D( @@ -114,7 +114,7 @@ def __init__( pad_mode="pad", has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(energy_flow_hidden_size)), + bias_init=Uniform(scale=1 / math.sqrt(l1_channels * 3 * 3)), ).to_float(dtype) self.connect_l2 = Conv2d( 24, @@ -125,7 +125,7 @@ def __init__( pad_mode="pad", has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(energy_flow_hidden_size)), + bias_init=Uniform(scale=1 / math.sqrt(24 * 3 * 3)), ).to_float(dtype) # Mid mid_layers = [ @@ -300,7 +300,7 @@ def __init__( pad_mode="pad", has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(l1_channels)), + bias_init=Uniform(scale=1 / math.sqrt(base_channels * 3 * 3)), ).to_float(dtype), ) self.connect_l2 = nn.SequentialCell( @@ -323,7 +323,7 @@ def __init__( pad_mode="pad", has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(24)), + bias_init=Uniform(scale=1 / math.sqrt(base_channels * 3 * 3)), ).to_float(dtype), ) # Out @@ -337,7 +337,7 @@ def __init__( pad_mode="pad", has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(24)), + bias_init=Uniform(scale=1 / math.sqrt(base_channels * 3 * 3)), ).to_float(dtype) self.inverse_wavelet_tranform_l1 = resolve_str_to_obj(l1_upsample_wavelet)(dtype=dtype) From 22ed93693d547541494dfaed1ca86d42900654c6 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 12 Nov 2024 10:51:07 +0800 Subject: [PATCH 045/133] print rec loss and p_loss & use fp32 for p_loss --- .../opensora/models/causalvideovae/model/losses/lpips.py | 1 + .../opensora/models/causalvideovae/model/losses/net_with_loss.py | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/lpips.py b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/lpips.py index 4c31a52a2e..bc0a602df5 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/lpips.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/lpips.py @@ -138,6 +138,7 @@ def construct(self, X): def normalize_tensor(x, eps=1e-10): + x = x.to(ms.float32) norm_factor = mint.sqrt((x**2).sum(1, keepdims=True)) return x / (norm_factor + eps) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py index 910a7235b3..76b34a0061 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py @@ -113,6 +113,7 @@ def loss_function( if self.perceptual_weight > 0: p_loss = self.perceptual_loss(x, recons) rec_loss = rec_loss + self.perceptual_weight * p_loss + # print(f"rec_loss {rec_loss.sum()/rec_loss.shape[0]}, " + f"p_loss: {p_loss.sum()/p_loss.shape[0]}" if self.perceptual_weight > 0 else "") nll_loss = rec_loss / mint.exp(self.logvar) + self.logvar if weights is not None: From 277b716d10df547d8a45f3c0e9453f58cb6fd558 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 12 Nov 2024 17:11:53 +0800 Subject: [PATCH 046/133] updates for graph mode --- .../model/losses/net_with_loss.py | 14 +++++++-- .../causalvideovae/model/modules/attention.py | 31 ++++++++++--------- .../causalvideovae/model/modules/conv.py | 4 +++ .../causalvideovae/model/modules/wavelet.py | 27 +++++++++++----- examples/opensora_pku/opensora/npu_config.py | 8 ++--- .../opensora/train/train_causalvae.py | 2 ++ .../scripts/causalvae/rec_image.sh | 1 + .../scripts/causalvae/rec_video.sh | 1 + .../scripts/causalvae/rec_video_folder.sh | 2 ++ .../scripts/causalvae/train_with_gan_loss.sh | 4 ++- .../train_with_gan_loss_multi_device.sh | 4 ++- 11 files changed, 68 insertions(+), 30 deletions(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py index 76b34a0061..c9d0bdf3cf 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py @@ -45,9 +45,12 @@ def __init__( learn_logvar: bool = False, wavelet_weight=0.01, loss_type: str = "l1", + print_losses: bool = False, ): super().__init__() + self.print_losses = print_losses + # build perceptual models for loss compute self.autoencoder = autoencoder # TODO: set dtype for LPIPS ? @@ -113,7 +116,13 @@ def loss_function( if self.perceptual_weight > 0: p_loss = self.perceptual_loss(x, recons) rec_loss = rec_loss + self.perceptual_weight * p_loss - # print(f"rec_loss {rec_loss.sum()/rec_loss.shape[0]}, " + f"p_loss: {p_loss.sum()/p_loss.shape[0]}" if self.perceptual_weight > 0 else "") + if self.print_losses: + print( + f"rec_loss {(rec_loss.sum()/rec_loss.shape[0]).asnumpy()}, " + + f"p_loss: {(p_loss.sum()/p_loss.shape[0]).asnumpy()}" + if self.perceptual_weight > 0 + else "" + ) nll_loss = rec_loss / mint.exp(self.logvar) + self.logvar if weights is not None: @@ -152,7 +161,8 @@ def loss_function( # d_weight = self.calculate_adaptive_weight(mean_nll_loss, g_loss, last_layer=last_layer) d_weight = self.disc_weight loss += d_weight * self.disc_factor * g_loss - # print(f"nll_loss: {mean_weighted_nll_loss.asnumpy():.4f}, kl_loss: {kl_loss.asnumpy():.4f}") + if self.print_losses: + print(f"nll_loss: {mean_weighted_nll_loss.asnumpy():.4f}, kl_loss: {kl_loss.asnumpy():.4f}") """ split = "train" diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/attention.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/attention.py index 3803ba9d85..692ea123b1 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/attention.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/attention.py @@ -7,7 +7,7 @@ from .normalize import Normalize try: - from opensora.npu_config import npu_config, set_run_dtype + from opensora.npu_config import npu_config except ImportError: npu_config = None @@ -88,19 +88,22 @@ def construct(self, x): dtype = ms.bfloat16 else: dtype = None - with set_run_dtype(q, dtype): - query, key, value = npu_config.set_current_run_dtype([q, k, v]) - hidden_states = npu_config.run_attention( - query, - key, - value, - attention_mask=None, - input_layout="BSH", - head_dim=c // 2, - head_num=2, # FIXME: different from torch. To make head_dim 256 instead of 512 - ) - - attn_output = npu_config.restore_dtype(hidden_states) + npu_config.current_run_dtype = dtype + npu_config.original_run_dtype = q.dtype + # with set_run_dtype(q, dtype): # graph mode does not support it + query, key, value = npu_config.set_current_run_dtype([q, k, v]) + hidden_states = npu_config.run_attention( + query, + key, + value, + attention_mask=None, + input_layout="BSH", + head_dim=c // 2, + head_num=2, # FIXME: different from torch. To make head_dim 256 instead of 512 + ) + npu_config.current_run_dtype = None + npu_config.original_run_dtype = None + attn_output = npu_config.restore_dtype(hidden_states) attn_output = attn_output.reshape(b, t, h, w, c).permute(0, 4, 1, 2, 3) h_ = self.proj_out(attn_output) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py index 4eb955a37c..bf8a5b34f0 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py @@ -1,6 +1,7 @@ import math from typing import Tuple, Union +import mindspore as ms from mindspore.common.initializer import HeUniform, Uniform try: @@ -77,6 +78,7 @@ def __init__( self.stride = kwargs.pop("stride", 1) self.padding = kwargs.pop("padding", 0) self.stride = cast_tuple(self.stride, 3) + conv_dtype = npu_config.conv_dtype if npu_config is not None else ms.bfloat16 if self.padding == 0: self.conv = nn.Conv3d( chan_in, @@ -89,6 +91,7 @@ def __init__( bias_init=Uniform( scale=1 / math.sqrt(chan_in * self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]) ), + dtype=conv_dtype, **kwargs, ) else: @@ -108,6 +111,7 @@ def __init__( bias_init=Uniform( scale=1 / math.sqrt(chan_in * self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]) ), + dtype=conv_dtype, **kwargs, ) self.enable_cached = enable_cached diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py index 74368d7435..a5df5fcb13 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py @@ -69,35 +69,46 @@ def construct(self, x): x = x.reshape(-1, 1, *x.shape[-3:]) low_low_low = self.h_conv(x) low_low_low = low_low_low.reshape( - b, low_low_low.shape[0] // b, *low_low_low.shape[-3:] + b, low_low_low.shape[0] // b, low_low_low.shape[-3], low_low_low.shape[-2], low_low_low.shape[-1] ) # (b c) 1 t h w -> b c t h w + low_low_high = self.g_conv(x) low_low_high = low_low_high.reshape( - b, low_low_high.shape[0] // b, *low_low_high.shape[-3:] + b, low_low_high.shape[0] // b, low_low_high.shape[-3], low_low_high.shape[-2], low_low_high.shape[-1] ) # (b c) 1 t h w -> b c t h w + low_high_low = self.hh_conv(x) low_high_low = low_high_low.reshape( - b, low_high_low.shape[0] // b, *low_high_low.shape[-3:] + b, low_high_low.shape[0] // b, low_high_low.shape[-3], low_high_low.shape[-2], low_high_low.shape[-1] ) # (b c) 1 t h w -> b c t h w + low_high_high = self.gh_conv(x) low_high_high = low_high_high.reshape( - b, low_high_high.shape[0] // b, *low_high_high.shape[-3:] + b, low_high_high.shape[0] // b, low_high_high.shape[-3], low_high_high.shape[-2], low_high_high.shape[-1] ) # (b c) 1 t h w -> b c t h w + high_low_low = self.h_v_conv(x) high_low_low = high_low_low.reshape( - b, high_low_low.shape[0] // b, *high_low_low.shape[-3:] + b, high_low_low.shape[0] // b, high_low_low.shape[-3], high_low_low.shape[-2], high_low_low.shape[-1] ) # (b c) 1 t h w -> b c t h w + high_low_high = self.g_v_conv(x) high_low_high = high_low_high.reshape( - b, high_low_high.shape[0] // b, *high_low_high.shape[-3:] + b, high_low_high.shape[0] // b, high_low_high.shape[-3], high_low_high.shape[-2], high_low_high.shape[-1] ) # (b c) 1 t h w -> b c t h w + high_high_low = self.hh_v_conv(x) high_high_low = high_high_low.reshape( - b, high_high_low.shape[0] // b, *high_high_low.shape[-3:] + b, high_high_low.shape[0] // b, high_high_low.shape[-3], high_high_low.shape[-2], high_high_low.shape[-1] ) # (b c) 1 t h w -> b c t h w + high_high_high = self.gh_v_conv(x) high_high_high = high_high_high.reshape( - b, high_high_high.shape[0] // b, *high_high_high.shape[-3:] + b, + high_high_high.shape[0] // b, + high_high_high.shape[-3], + high_high_high.shape[-2], + high_high_high.shape[-1], ) # (b c) 1 t h w -> b c t h w output = mint.cat( diff --git a/examples/opensora_pku/opensora/npu_config.py b/examples/opensora_pku/opensora/npu_config.py index 0cb4d9cf8e..ae3b3b7806 100644 --- a/examples/opensora_pku/opensora/npu_config.py +++ b/examples/opensora_pku/opensora/npu_config.py @@ -67,7 +67,7 @@ def set_npu_env(self, args): rank_id, device_num = init_env( mode=args.mode, device_target=args.device, - distributed=args.use_parallel, + distributed=getattr(args, "use_parallel", False), precision_mode=getattr(args, "precision_mode", None), jit_level=getattr(args, "jit_level", None), jit_syntax_level=getattr(args, "jit_syntax_level", "strict"), @@ -131,7 +131,7 @@ def _run(self, operator, x, tmp_dtype, out_dtype=None): if self.on_npu: if out_dtype is None: out_dtype = x.dtype - x = operator.to_float(tmp_dtype)(x.to(tmp_dtype)) + x = operator(x.to(tmp_dtype)) x = x.to(out_dtype) return x else: @@ -278,9 +278,9 @@ def ms_flash_attention( # If we did padding before calculate attention, undo it! if head_dim_padding > 0: if input_layout == "BNSD": - hidden_states = hidden_states_padded[..., :head_dim] + hidden_states = hidden_states_padded[:, :, :, :head_dim] else: - hidden_states = hidden_states_padded.view(Bs, query_tokens, heads, -1)[..., :head_dim] + hidden_states = hidden_states_padded.view(Bs, query_tokens, heads, -1)[:, :, :, :head_dim] hidden_states = hidden_states.view(Bs, query_tokens, -1) else: hidden_states = hidden_states_padded diff --git a/examples/opensora_pku/opensora/train/train_causalvae.py b/examples/opensora_pku/opensora/train/train_causalvae.py index 9cc2e6ff9f..ced2e94edd 100644 --- a/examples/opensora_pku/opensora/train/train_causalvae.py +++ b/examples/opensora_pku/opensora/train/train_causalvae.py @@ -104,6 +104,7 @@ def main(args): perceptual_weight=args.perceptual_weight, loss_type=args.loss_type, wavelet_weight=args.wavelet_weight, + print_losses=args.print_losses, ) disc_start = args.disc_start @@ -627,6 +628,7 @@ def parse_causalvae_train_args(parser): parser.add_argument("--logvar_init", type=float, default=0.0, help="") parser.add_argument("--wavelet_loss", action="store_true", help="") parser.add_argument("--wavelet_weight", type=float, default=0.1, help="") + parser.add_argument("--print_losses", action="store_true", help="Whether to print multiple losses during training") return parser diff --git a/examples/opensora_pku/scripts/causalvae/rec_image.sh b/examples/opensora_pku/scripts/causalvae/rec_image.sh index 38fc720683..dfb5c6e08c 100644 --- a/examples/opensora_pku/scripts/causalvae/rec_image.sh +++ b/examples/opensora_pku/scripts/causalvae/rec_image.sh @@ -6,3 +6,4 @@ python examples/rec_image.py \ --device Ascend \ --short_size 512 \ --mode 1 \ + --jit_syntax_level lax \ diff --git a/examples/opensora_pku/scripts/causalvae/rec_video.sh b/examples/opensora_pku/scripts/causalvae/rec_video.sh index 4a9716c28b..244f7acc37 100644 --- a/examples/opensora_pku/scripts/causalvae/rec_video.sh +++ b/examples/opensora_pku/scripts/causalvae/rec_video.sh @@ -11,3 +11,4 @@ python examples/rec_video.py \ --fps 30 \ --enable_tiling \ --mode 1 \ + --jit_syntax_level lax \ diff --git a/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh b/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh index a52d7c3332..7b88f5c63a 100644 --- a/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh +++ b/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh @@ -13,3 +13,5 @@ python examples/rec_video_folder.py \ --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ --enable_tiling \ --tile_overlap_factor 0.125 \ + --mode 1 \ + --jit_syntax_level lax \ diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh index d35b6abdaf..8719e70e9c 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh @@ -30,4 +30,6 @@ python opensora/train/train_causalvae.py \ --disc_cls causalvideovae.model.losses.LPIPSWithDiscriminator3D \ --disc_start 0 \ --wavelet_loss \ - --wavelet_weight 0.1 + --wavelet_weight 0.1 \ + --mode 1 \ + --jit_syntax_level lax \ diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh index d5817e4e19..f7d41ad885 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh @@ -38,4 +38,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --disc_cls causalvideovae.model.losses.LPIPSWithDiscriminator3D \ --disc_start 0 \ --wavelet_loss \ - --wavelet_weight 0.1 + --wavelet_weight 0.1 \ + --mode 1 \ + --jit_syntax_level lax \ From 9f5bf8895b33ec63b1f3d9a5c26758413c74c036 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 13 Nov 2024 14:17:12 +0800 Subject: [PATCH 047/133] use ops.standard_normal --- .../models/causalvideovae/model/vae/modeling_wfvae.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py index b1f0dc2fb8..aeeda6b452 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py @@ -6,7 +6,7 @@ from opensora.npu_config import npu_config import mindspore as ms -from mindspore import mint, nn +from mindspore import mint, nn, ops from mindspore.common.initializer import HeUniform, Uniform from mindone.diffusers import __version__ @@ -443,7 +443,7 @@ def __init__( ) self.exp = mint.exp - self.stdnormal = mint.normal + self.stdnormal = ops.standard_normal self.update_parameters_name() # update parameter names to solve pname mismatch if use_recompute: @@ -577,7 +577,7 @@ def sample(self, mean, logvar): # sample z from latent distribution logvar = mint.clamp(logvar, -30.0, 20.0) std = self.exp(0.5 * logvar) - z = mean + std * self.stdnormal(size=mean.shape) + z = mean + std * self.stdnormal(mean.shape) return z From 51d695cfc71d35c6b051397ab598c9c8ffc3cb9e Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 13 Nov 2024 16:22:00 +0800 Subject: [PATCH 048/133] default to print losses --- examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh | 1 + .../scripts/causalvae/train_with_gan_loss_multi_device.sh | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh index 8719e70e9c..7577b9ec8b 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh @@ -33,3 +33,4 @@ python opensora/train/train_causalvae.py \ --wavelet_weight 0.1 \ --mode 1 \ --jit_syntax_level lax \ + --print_losses diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh index f7d41ad885..f9f87e10bc 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh @@ -41,3 +41,4 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --wavelet_weight 0.1 \ --mode 1 \ --jit_syntax_level lax \ + --print_losses From 60472fba86cdbd9a54c314f72e78d32dd6e4e94d Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 13 Nov 2024 16:26:37 +0800 Subject: [PATCH 049/133] use max_grad_norm 0.1 for vae training --- examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh | 1 + .../scripts/causalvae/train_with_gan_loss_multi_device.sh | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh index 7577b9ec8b..83bae9b0de 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh @@ -17,6 +17,7 @@ python opensora/train/train_causalvae.py \ --optim adamw \ --betas 0.9 0.999 \ --clip_grad True \ + --max_grad_norm 0.1 \ --weight_decay 0.0 \ --mode 1 \ --init_loss_scale 65536 \ diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh index f9f87e10bc..416384aa3f 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh @@ -24,6 +24,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --optim adamw \ --betas 0.9 0.999 \ --clip_grad True \ + --max_grad_norm 0.1 \ --weight_decay 0.0 \ --mode 1 \ --init_loss_scale 65536 \ From af81883f0a91821e2b49aeb64068bd098e5e9882 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 13 Nov 2024 16:34:02 +0800 Subject: [PATCH 050/133] discriminator bf16 conv3d --- .../models/causalvideovae/model/losses/discriminator.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py index 80f6df9b9f..8c83c9d4af 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py @@ -40,6 +40,7 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.f use_bias = norm_layer.func != nn.BatchNorm3d else: use_bias = norm_layer != nn.BatchNorm3d + conv_dtype = npu_config.conv_dtype if npu_config is not None else ms.bfloat16 kw = 3 padw = 1 @@ -54,6 +55,7 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.f has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), bias_init=Uniform(scale=1 / math.sqrt(input_nc * kw * kw * kw)), + dtype=conv_dtype, ), nn.LeakyReLU(0.2).to_float(self.dtype), ] @@ -73,6 +75,7 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.f has_bias=use_bias, weight_init=HeUniform(negative_slope=math.sqrt(5)), bias_init=Uniform(scale=1 / math.sqrt(ndf * nf_mult_prev * kw * kw * kw)), + dtype=conv_dtype, ), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2).to_float(self.dtype), @@ -91,6 +94,7 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.f has_bias=use_bias, weight_init=HeUniform(negative_slope=math.sqrt(5)), bias_init=Uniform(scale=1 / math.sqrt(ndf * nf_mult_prev * kw * kw * kw)), + dtype=conv_dtype, ), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2).to_float(self.dtype), @@ -107,6 +111,7 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.f has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), bias_init=Uniform(scale=1 / math.sqrt(ndf * nf_mult * kw * kw * kw)), + dtype=conv_dtype, ) ] # output 1 channel prediction map self.main = nn.CellList(sequence) From ecdb0d10374283bd0a61faac923fe6b7a09e749a Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 14 Nov 2024 13:41:59 +0800 Subject: [PATCH 051/133] set gen and disc weight decay --- examples/opensora_pku/opensora/train/train_causalvae.py | 9 ++++++--- .../scripts/causalvae/train_with_gan_loss.sh | 3 +-- .../causalvae/train_with_gan_loss_multi_device.sh | 3 +-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/opensora_pku/opensora/train/train_causalvae.py b/examples/opensora_pku/opensora/train/train_causalvae.py index ced2e94edd..6e490da12d 100644 --- a/examples/opensora_pku/opensora/train/train_causalvae.py +++ b/examples/opensora_pku/opensora/train/train_causalvae.py @@ -254,7 +254,7 @@ def main(args): name=args.optim, betas=args.betas, group_strategy=args.group_strategy, - weight_decay=args.weight_decay, + weight_decay=args.gen_wd, lr=lr, ) loss_scaler_ae = create_loss_scaler(args) @@ -267,7 +267,7 @@ def main(args): name=args.optim, lr=lr, # since lr is a shared list group_strategy=args.group_strategy, - weight_decay=args.weight_decay, + weight_decay=args.disc_wd, ) loss_scaler_disc = create_loss_scaler(args) scaling_sens_d = loss_scaler_disc.loss_scale_value @@ -347,7 +347,8 @@ def main(args): f"Rescale size: {args.resolution}", f"Crop size: {args.resolution}", f"Number of frames: {args.video_num_frames}", - f"Weight decay: {args.weight_decay}", + f"Weight decay: generator {args.gen_wd}" + + (f", discriminator {args.disc_wd}" if args.use_discriminator else ""), f"Grad accumulation steps: {args.gradient_accumulation_steps}", f"Num of training steps: {total_train_steps}", f"Loss scaler: {args.loss_scaler_type}", @@ -629,6 +630,8 @@ def parse_causalvae_train_args(parser): parser.add_argument("--wavelet_loss", action="store_true", help="") parser.add_argument("--wavelet_weight", type=float, default=0.1, help="") parser.add_argument("--print_losses", action="store_true", help="Whether to print multiple losses during training") + parser.add_argument("--gen_wd", type=float, default=1e-4, help="weight decay for generator") + parser.add_argument("--disc_wd", type=float, default=0.01, help="weight decay for discriminator") return parser diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh index 83bae9b0de..931d3ad985 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh @@ -17,8 +17,7 @@ python opensora/train/train_causalvae.py \ --optim adamw \ --betas 0.9 0.999 \ --clip_grad True \ - --max_grad_norm 0.1 \ - --weight_decay 0.0 \ + --max_grad_norm 1.0 \ --mode 1 \ --init_loss_scale 65536 \ --jit_level "O0" \ diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh index 416384aa3f..73215d0767 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh @@ -24,8 +24,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --optim adamw \ --betas 0.9 0.999 \ --clip_grad True \ - --max_grad_norm 0.1 \ - --weight_decay 0.0 \ + --max_grad_norm 1.0 \ --mode 1 \ --init_loss_scale 65536 \ --jit_level "O0" \ From bac8c79c0bcf275b90d33c9485fac9321d55348f Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 14 Nov 2024 14:03:31 +0800 Subject: [PATCH 052/133] change training steps order: one loss per step --- .../opensora/train/train_causalvae.py | 104 ++++++++++++------ 1 file changed, 73 insertions(+), 31 deletions(-) diff --git a/examples/opensora_pku/opensora/train/train_causalvae.py b/examples/opensora_pku/opensora/train/train_causalvae.py index 6e490da12d..f4eaccad87 100644 --- a/examples/opensora_pku/opensora/train/train_causalvae.py +++ b/examples/opensora_pku/opensora/train/train_causalvae.py @@ -11,7 +11,7 @@ import yaml import mindspore as ms -from mindspore import Model +from mindspore import Model, nn from mindspore.train.callback import TimeMonitor sys.path.append(".") @@ -38,6 +38,27 @@ logger = logging.getLogger(__name__) +def set_train(modules): + for module in modules: + if isinstance(module, nn.Cell): + module.set_train(True) + + +def set_eval(modules): + for module in modules: + if isinstance(module, nn.Cell): + module.set_train(False) + + +def set_modules_requires_grad(modules, requires_grad): + for module in modules: + if isinstance(module, nn.Cell): + for param in module.get_parameters(): + param.requires_grad = requires_grad + elif isinstance(module, ms.Parameter): + module.requires_grad = requires_grad + + def main(args): # 1. init rank_id, device_num = npu_config.set_npu_env(args) @@ -247,8 +268,10 @@ def main(args): update_logvar = False # in torch, ae_with_loss.logvar is not updated. if update_logvar: ae_params_to_update = [ae_with_loss.autoencoder.trainable_params(), ae_with_loss.logvar] + ae_modules_to_update = [ae_with_loss.autoencoder, ae_with_loss.logvar] else: ae_params_to_update = ae_with_loss.autoencoder.trainable_params() + ae_modules_to_update = [ae_with_loss.autoencoder] optim_ae = create_optimizer( ae_params_to_update, name=args.optim, @@ -423,48 +446,67 @@ def main(args): for epoch in range(start_epoch, args.epochs): start_time_e = time.time() + set_train(ae_modules_to_update) for step, data in enumerate(ds_iter): start_time_s = time.time() x = data["video"] global_step = epoch * dataset_size + step - global_step = ms.Tensor(global_step, dtype=ms.int64) - - # NOTE: inputs must match the order in GeneratorWithLoss.construct - loss_ae_t, overflow, scaling_sens = training_step_ae(x, global_step) - if isinstance(scaling_sens, ms.Parameter): - scaling_sens = scaling_sens.value() - if global_step >= disc_start: - loss_disc_t, overflow_d, scaling_sens_d = training_step_disc(x, global_step) - if isinstance(scaling_sens_d, ms.Parameter): - scaling_sens_d = scaling_sens_d.value() + if global_step % 2 == 1 and global_step >= disc_start: + set_modules_requires_grad(ae_modules_to_update, False) + step_gen = False + step_dis = True + else: + set_modules_requires_grad(ae_modules_to_update, True) + step_gen = True + step_dis = False + assert step_gen or step_dis, "You should backward either Gen or Dis in a step." + global_step = ms.Tensor(global_step, dtype=ms.int64) cur_global_step = epoch * dataset_size + step + 1 # starting from 1 for logging - if overflow: - logger.warning( - f"Overflow occurs in step {cur_global_step} in autoencoder" - + (", drop update." if args.drop_overflow_update else ", still update.") - ) - if global_step >= disc_start and overflow_d: - logger.warning( - f"Overflow occurs in step {cur_global_step} in discriminator" - + (", drop update." if args.drop_overflow_update else ", still update.") - ) + # Generator Step + if step_gen: + # NOTE: inputs must match the order in GeneratorWithLoss.construct + loss_ae_t, overflow, scaling_sens = training_step_ae(x, global_step) + if isinstance(scaling_sens, ms.Parameter): + scaling_sens = scaling_sens.value() + + if overflow: + logger.warning( + f"Overflow occurs in step {cur_global_step} in autoencoder" + + (", drop update." if args.drop_overflow_update else ", still update.") + ) + # Discriminator Step + if step_dis: + if global_step >= disc_start: + loss_disc_t, overflow_d, scaling_sens_d = training_step_disc(x, global_step) + if isinstance(scaling_sens_d, ms.Parameter): + scaling_sens_d = scaling_sens_d.value() + if overflow_d: + logger.warning( + f"Overflow occurs in step {cur_global_step} in discriminator" + + (", drop update." if args.drop_overflow_update else ", still update.") + ) # log step_time = time.time() - start_time_s if step % args.log_interval == 0: - loss_ae = float(loss_ae_t.asnumpy()) - logger.info( - f"E: {epoch+1}, S: {step+1}, Loss ae: {loss_ae:.4f}, ae loss scaler {scaling_sens}," - + f" Step time: {step_time*1000:.2f}ms" - ) - if global_step >= disc_start: + if step_gen: + loss_ae = float(loss_ae_t.asnumpy()) + logger.info( + f"E: {epoch+1}, S: {step+1}, Loss ae: {loss_ae:.4f}, ae loss scaler {scaling_sens}," + + f" Step time: {step_time*1000:.2f}ms" + ) + loss_disc = -1 # no discriminator loss, dummy value + if step_dis and global_step >= disc_start: loss_disc = float(loss_disc_t.asnumpy()) - logger.info(f"Loss disc: {loss_disc:.4f}, disc loss scaler {scaling_sens_d}") - loss_log_file.write(f"{cur_global_step}\t{loss_ae:.7f}\t{loss_disc:.7f}\t{step_time:.2f}\n") - else: - loss_log_file.write(f"{cur_global_step}\t{loss_ae:.7f}\t{0.0}\t{step_time:.2f}\n") + logger.info( + f"E: {epoch+1}, S: {step+1}, Loss disc: {loss_disc:.4f}, disc loss scaler {scaling_sens_d}," + + f" Step time: {step_time*1000:.2f}ms" + ) + loss_ae = -1 # no generator loss, dummy value + + loss_log_file.write(f"{cur_global_step}\t{loss_ae:.7f}\t{loss_disc:.7f}\t{step_time:.2f}\n") loss_log_file.flush() if rank_id == 0 and step_mode: From 5f4a2ff0ad576e5a5a3e7c0d9936d06c6cd27127 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:26:58 +0800 Subject: [PATCH 053/133] update printing log --- .../model/losses/net_with_loss.py | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py index c9d0bdf3cf..915de1886e 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py @@ -111,18 +111,26 @@ def loss_function( # 2.1 reconstruction loss in pixels rec_loss = self.loss_func(x, recons) - + if self.print_losses: + print(f"rec_loss {(rec_loss.sum()/rec_loss.shape[0]).asnumpy()}, ") + # debug + # if (rec_loss.sum()/rec_loss.shape[0]).asnumpy() <0: + # print("negative loss!") + # print(f"x min {x.min()}, max {x.max()}, mean {x.mean()}") + # print(f"recons min {recons.min()}, max {recons.max()}, mean {recons.mean()}") # 2.2 perceptual loss if self.perceptual_weight > 0: p_loss = self.perceptual_loss(x, recons) rec_loss = rec_loss + self.perceptual_weight * p_loss - if self.print_losses: - print( - f"rec_loss {(rec_loss.sum()/rec_loss.shape[0]).asnumpy()}, " - + f"p_loss: {(p_loss.sum()/p_loss.shape[0]).asnumpy()}" - if self.perceptual_weight > 0 - else "" - ) + if self.print_losses: + print( + f"new rec_loss {(rec_loss.sum()/rec_loss.shape[0]).asnumpy()}, " + + f"p_loss: {(p_loss.sum()/p_loss.shape[0]).asnumpy()}" + ) + # if (p_loss.sum()/p_loss.shape[0]).asnumpy()< 0: + # print("negative loss!") + # print(f"x min {x.min()}, max {x.max()}, mean {x.mean()}") + # print(f"recons min {recons.min()}, max {recons.max()}, mean {recons.mean()}") nll_loss = rec_loss / mint.exp(self.logvar) + self.logvar if weights is not None: From 23598dfa3e0ce1b19d59fee362fdf4246981b5c5 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 21 Nov 2024 20:54:03 +0800 Subject: [PATCH 054/133] set norm_dtype bf16 --- examples/opensora_pku/opensora/npu_config.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/opensora_pku/opensora/npu_config.py b/examples/opensora_pku/opensora/npu_config.py index ae3b3b7806..2a37bd203d 100644 --- a/examples/opensora_pku/opensora/npu_config.py +++ b/examples/opensora_pku/opensora/npu_config.py @@ -45,6 +45,7 @@ def __init__(self): self.replaced_type = ms.float32 self.conv_dtype = ms.bfloat16 # FIXME: torch uses float16 + self.norm_dtype = ms.bfloat16 # use bf16 for group_norm, layer_norm and batch_norm. Set to fp32 when training if self.enable_FA and self.enable_FP32: self.inf_float = -10000.0 else: @@ -138,13 +139,13 @@ def _run(self, operator, x, tmp_dtype, out_dtype=None): return operator(x) def run_group_norm(self, operator, x): - return self._run(operator, x, ms.float32) + return self._run(operator, x, self.norm_dtype) def run_layer_norm(self, operator, x): - return self._run(operator, x, ms.float32) + return self._run(operator, x, self.norm_dtype) def run_batch_norm(self, operator, x): - return self._run(operator, x, ms.float32) + return self._run(operator, x, self.norm_dtype) def run_conv3d(self, operator, x, out_dtype): return self._run(operator, x, self.conv_dtype, out_dtype) From 27bbd072535193ac6efcdcd4a43f1cde0f536200 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 22 Nov 2024 11:15:15 +0800 Subject: [PATCH 055/133] correct save video bug --- examples/opensora_pku/opensora/sample/rec_video.py | 2 +- examples/opensora_pku/opensora/utils/video_utils.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/opensora_pku/opensora/sample/rec_video.py b/examples/opensora_pku/opensora/sample/rec_video.py index a6fbc6c244..530e5a9fe9 100644 --- a/examples/opensora_pku/opensora/sample/rec_video.py +++ b/examples/opensora_pku/opensora/sample/rec_video.py @@ -25,7 +25,6 @@ mindone_lib_path = os.path.abspath("../../") sys.path.insert(0, mindone_lib_path) from mindone.utils.logger import set_logger -from mindone.visualize.videos import save_videos sys.path.append(".") from functools import partial @@ -36,6 +35,7 @@ from opensora.models.causalvideovae import ae_wrapper from opensora.npu_config import npu_config from opensora.utils.utils import get_precision +from opensora.utils.video_utils import save_videos logger = logging.getLogger(__name__) diff --git a/examples/opensora_pku/opensora/utils/video_utils.py b/examples/opensora_pku/opensora/utils/video_utils.py index 65b2edccfa..e8d73b9a1a 100644 --- a/examples/opensora_pku/opensora/utils/video_utils.py +++ b/examples/opensora_pku/opensora/utils/video_utils.py @@ -27,7 +27,8 @@ def create_video_from_rgb_numpy_arrays(image_arrays, output_file, fps: Union[int # Write each frame to the video for img in image_arrays: - video_writer.write(img) + img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + video_writer.write(img_bgr) # Release the VideoWriter video_writer.release() From e63bc102caa811b4f3343f1ebbb6fc6df92a789a Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 22 Nov 2024 11:50:09 +0800 Subject: [PATCH 056/133] print ops data types --- examples/opensora_pku/opensora/npu_config.py | 18 ++++++++---------- .../opensora/train/train_causalvae.py | 2 ++ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/opensora_pku/opensora/npu_config.py b/examples/opensora_pku/opensora/npu_config.py index 2a37bd203d..9215cb748d 100644 --- a/examples/opensora_pku/opensora/npu_config.py +++ b/examples/opensora_pku/opensora/npu_config.py @@ -5,6 +5,7 @@ import subprocess from contextlib import contextmanager +import pandas as pd from opensora.utils.ms_utils import init_env import mindspore as ms @@ -64,6 +65,13 @@ def __init__(self): self.FA_dtype = ms.bfloat16 assert self.FA_dtype in [ms.float16, ms.bfloat16], f"Unsupported flash-attention dtype: {self.FA_dtype}" + def print_ops_dtype_info(self): + # print data types for some key operators + headers = ["Conv3D dtype", "FA dtype", "Norm dtype", "Interpolate, AvgPool"] + values = [[str(self.conv_dtype), str(self.FA_dtype), str(self.norm_dtype), str(self.replaced_type)]] + df = pd.DataFrame(values, columns=headers) + print(df) + def set_npu_env(self, args): rank_id, device_num = init_env( mode=args.mode, @@ -160,16 +168,6 @@ def run_pool_2d(self, operator, x, kernel_size, stride): x = operator(x, kernel_size=kernel_size, stride=stride) return x - def run_pad_2d(self, operator, x, pad, mode="constant"): - if self.on_npu: - x_dtype = x.dtype - x = x.to(self.replaced_type) - x = operator(x, pad, mode) - x = x.to(x_dtype) - else: - x = operator(x, pad, mode) - return x - def run_interpolate(self, operator, x, scale_factor=None): if self.on_npu: x_dtype = x.dtype diff --git a/examples/opensora_pku/opensora/train/train_causalvae.py b/examples/opensora_pku/opensora/train/train_causalvae.py index f4eaccad87..48331be9a3 100644 --- a/examples/opensora_pku/opensora/train/train_causalvae.py +++ b/examples/opensora_pku/opensora/train/train_causalvae.py @@ -62,6 +62,8 @@ def set_modules_requires_grad(modules, requires_grad): def main(args): # 1. init rank_id, device_num = npu_config.set_npu_env(args) + npu_config.norm_dtype = ms.float32 # to train causal vae, set norm dtype to fp32 + npu_config.print_ops_dtype_info() dtype = get_precision(args.precision) if args.exp_name is not None and len(args.exp_name) > 0: From 1a5e9eb29e4c28e3693213547b4ea2e8ee22ba03 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 22 Nov 2024 11:55:46 +0800 Subject: [PATCH 057/133] resize nearest neighbor to npu_config.run --- .../causalvideovae/model/modules/updownsample.py | 12 +++++++++--- examples/opensora_pku/opensora/npu_config.py | 6 +++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py index 3e7c7626e4..679f280940 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py @@ -11,6 +11,11 @@ from .ops import cast_tuple, video_to_image +class ResizeNearestNeighbor(nn.Cell): + def construct(self, x, size, scale_factor=None): + return ops.interpolate(x, size=size, scale_factor=scale_factor, mode="nearest") + + class Upsample(nn.Cell): def __init__(self, in_channels, out_channels, with_conv=True, dtype=ms.float32): super().__init__() @@ -28,12 +33,13 @@ def __init__(self, in_channels, out_channels, with_conv=True, dtype=ms.float32): weight_init=HeUniform(negative_slope=math.sqrt(5)), bias_init=Uniform(scale=1 / math.sqrt(in_channels * 3 * 3)), ).to_float(self.dtype) + self.resize = ResizeNearestNeighbor() @video_to_image def construct(self, x): in_shape = x.shape[-2:] out_shape = tuple(2 * x for x in in_shape) - x = ops.ResizeNearestNeighbor(out_shape)(x) + x = npu_config.run_interpolate(self.resize, x, size=out_shape) if self.with_conv: x = self.conv(x) return x @@ -298,8 +304,8 @@ def construct(self, x): class TrilinearInterpolate(nn.Cell): - def construct(self, x, scale_factor): - return ops.interpolate(x, scale_factor=scale_factor, mode="trilinear") + def construct(self, x, scale_factor, size=None): + return ops.interpolate(x, scale_factor=scale_factor, size=size, mode="trilinear") class Spatial2xTime2x3DUpsample(nn.Cell): diff --git a/examples/opensora_pku/opensora/npu_config.py b/examples/opensora_pku/opensora/npu_config.py index 9215cb748d..008cc8b8d7 100644 --- a/examples/opensora_pku/opensora/npu_config.py +++ b/examples/opensora_pku/opensora/npu_config.py @@ -168,14 +168,14 @@ def run_pool_2d(self, operator, x, kernel_size, stride): x = operator(x, kernel_size=kernel_size, stride=stride) return x - def run_interpolate(self, operator, x, scale_factor=None): + def run_interpolate(self, operator, x, size=None, scale_factor=None): if self.on_npu: x_dtype = x.dtype x = x.to(self.replaced_type) - x = operator(x, scale_factor=scale_factor) + x = operator(x, size=size, scale_factor=scale_factor) x = x.to(x_dtype) else: - x = operator(x, scale_factor=scale_factor) + x = operator(x, size=size, scale_factor=scale_factor) return x def run_attention(self, query, key, value, attention_mask, input_layout, head_dim, head_num): From e40bcb63f94483075989f5c65b36ce66a4032106 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 22 Nov 2024 14:40:56 +0800 Subject: [PATCH 058/133] print ops info --- examples/opensora_pku/examples/rec_image.py | 1 + examples/opensora_pku/examples/rec_video.py | 1 + .../opensora_pku/examples/rec_video_folder.py | 1 + .../opensora_pku/opensora/sample/rec_image.py | 1 + .../opensora_pku/opensora/sample/rec_video.py | 1 + .../opensora_pku/opensora/sample/sample.py | 35 +++++++++++-------- 6 files changed, 25 insertions(+), 15 deletions(-) diff --git a/examples/opensora_pku/examples/rec_image.py b/examples/opensora_pku/examples/rec_image.py index d7f64a2d8d..c3acfff217 100644 --- a/examples/opensora_pku/examples/rec_image.py +++ b/examples/opensora_pku/examples/rec_image.py @@ -65,6 +65,7 @@ def main(args): image_path = args.image_path short_size = args.short_size npu_config.set_npu_env(args) + npu_config.print_ops_dtype_info() set_logger(name="", output_dir=args.output_path, rank=0) dtype = get_precision(args.precision) diff --git a/examples/opensora_pku/examples/rec_video.py b/examples/opensora_pku/examples/rec_video.py index d1eea5234d..28334b055d 100644 --- a/examples/opensora_pku/examples/rec_video.py +++ b/examples/opensora_pku/examples/rec_video.py @@ -104,6 +104,7 @@ def transform_to_rgb(x, rescale_to_uint8=True): def main(args): npu_config.set_npu_env(args) + npu_config.print_ops_dtype_info() dtype = get_precision(args.precision) set_logger(name="", output_dir=args.output_path, rank=0) if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint): diff --git a/examples/opensora_pku/examples/rec_video_folder.py b/examples/opensora_pku/examples/rec_video_folder.py index 727116a6bd..9f6195ee96 100644 --- a/examples/opensora_pku/examples/rec_video_folder.py +++ b/examples/opensora_pku/examples/rec_video_folder.py @@ -46,6 +46,7 @@ def main(args): num_workers = args.num_workers assert args.dataset_name == "video", "Only support video reconstruction!" rank_id, device_num = npu_config.set_npu_env(args) + npu_config.print_ops_dtype_info() dtype = get_precision(args.precision) if not os.path.exists(args.generated_video_dir): diff --git a/examples/opensora_pku/opensora/sample/rec_image.py b/examples/opensora_pku/opensora/sample/rec_image.py index 6d5984d962..ed0e84e8b0 100644 --- a/examples/opensora_pku/opensora/sample/rec_image.py +++ b/examples/opensora_pku/opensora/sample/rec_image.py @@ -63,6 +63,7 @@ def main(args): image_path = args.image_path short_size = args.short_size npu_config.set_npu_env(args) + npu_config.print_ops_dtype_info() set_logger(name="", output_dir=args.output_path, rank=0) dtype = get_precision(args.precision) diff --git a/examples/opensora_pku/opensora/sample/rec_video.py b/examples/opensora_pku/opensora/sample/rec_video.py index 530e5a9fe9..17abb677fa 100644 --- a/examples/opensora_pku/opensora/sample/rec_video.py +++ b/examples/opensora_pku/opensora/sample/rec_video.py @@ -114,6 +114,7 @@ def transform_to_rgb(x, rescale_to_uint8=True): def main(args): npu_config.set_npu_env(args) + npu_config.print_ops_dtype_info() dtype = get_precision(args.precision) set_logger(name="", output_dir=args.output_path, rank=0) if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint): diff --git a/examples/opensora_pku/opensora/sample/sample.py b/examples/opensora_pku/opensora/sample/sample.py index 1320dd90f2..8e0fd6621a 100644 --- a/examples/opensora_pku/opensora/sample/sample.py +++ b/examples/opensora_pku/opensora/sample/sample.py @@ -1,4 +1,6 @@ -import os, sys +import os +import sys + # TODO: remove in future when mindone is ready for install mindone_lib_path = os.path.abspath("../../") sys.path.insert(0, mindone_lib_path) @@ -8,13 +10,13 @@ import time from opensora.npu_config import npu_config -from opensora.utils.sample_utils import ( - prepare_pipeline, get_args, run_model_and_save_samples -) -# from opensora.sample.caption_refiner import OpenSoraCaptionRefiner +from opensora.utils.sample_utils import get_args, prepare_pipeline, run_model_and_save_samples from mindone.utils.logger import set_logger +# from opensora.sample.caption_refiner import OpenSoraCaptionRefiner + + logger = logging.getLogger(__name__) if __name__ == "__main__": @@ -22,30 +24,33 @@ save_dir = args.save_img_path os.makedirs(save_dir, exist_ok=True) set_logger(name="", output_dir=save_dir) - + # 1. init environment rank_id, device_num = npu_config.set_npu_env(args) - + npu_config.print_ops_dtype_info() + # 2. build models and pipeline - if args.num_frames != 1 and args.enhance_video is not None: #TODO + if args.num_frames != 1 and args.enhance_video is not None: # TODO from opensora.sample.VEnhancer.enhance_a_video import VEnhancer - enhance_video_model = VEnhancer(model_path=args.enhance_video, version='v2', device=args.device) + + enhance_video_model = VEnhancer(model_path=args.enhance_video, version="v2", device=args.device) else: enhance_video_model = None - - pipeline = prepare_pipeline(args) # build I2V/T2V pipeline - - if args.caption_refiner is not None: #TODO: TO TEST + + pipeline = prepare_pipeline(args) # build I2V/T2V pipeline + + if args.caption_refiner is not None: # TODO: TO TEST caption_refiner_model = OpenSoraCaptionRefiner(args.caption_refiner, dtype=ms.float16) else: caption_refiner_model = None # 3. inference start_time = time.time() - run_model_and_save_samples(args, pipeline, rank_id, device_num, save_dir, caption_refiner_model, enhance_video_model) + run_model_and_save_samples( + args, pipeline, rank_id, device_num, save_dir, caption_refiner_model, enhance_video_model + ) end_time = time.time() time_cost = end_time - start_time logger.info(f"Inference time cost: {time_cost:0.3f}s") logger.info(f"Inference speed: {len(args.text_prompt) / time_cost:0.3f} samples/s") logger.info(f"{'latents' if args.save_latents else 'videos' } saved to {save_dir}") - \ No newline at end of file From 7606fc90b73893eb1a0e976fcd661290a92e65bf Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 22 Nov 2024 15:23:30 +0800 Subject: [PATCH 059/133] update dataset file --- examples/opensora_pku/scripts/train_data/video_data_v1_2.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/opensora_pku/scripts/train_data/video_data_v1_2.txt b/examples/opensora_pku/scripts/train_data/video_data_v1_2.txt index d2853a91d2..af4c65966e 100644 --- a/examples/opensora_pku/scripts/train_data/video_data_v1_2.txt +++ b/examples/opensora_pku/scripts/train_data/video_data_v1_2.txt @@ -1 +1 @@ -/home_host/susan/workspace/datasets/Open-Sora-Plan-v1.2.0,/home_host/susan/workspace/datasets/Open-Sora-Plan-v1.2.0/mixkit_emb-len=512,/home_host/susan/workspace/datasets/Open-Sora-Plan-v1.2.0/v1.1.0_HQ_part1_Traffic_train.json +datasets/Open-Sora-Plan-v1.2.0,datasets/Open-Sora-Plan-v1.2.0/mixkit_emb-len=512,datasets/Open-Sora-Plan-v1.2.0/v1.1.0_HQ_part1_Traffic_train.json From 4c93b492001c5f3bd032c2acd0a65be47dec83e5 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 22 Nov 2024 15:39:00 +0800 Subject: [PATCH 060/133] update train t2v script --- examples/opensora_pku/opensora/npu_config.py | 16 +++++++++++--- .../opensora/train/train_t2v_diffusers.py | 22 ++++--------------- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/examples/opensora_pku/opensora/npu_config.py b/examples/opensora_pku/opensora/npu_config.py index 008cc8b8d7..3aaafa25aa 100644 --- a/examples/opensora_pku/opensora/npu_config.py +++ b/examples/opensora_pku/opensora/npu_config.py @@ -70,16 +70,26 @@ def print_ops_dtype_info(self): headers = ["Conv3D dtype", "FA dtype", "Norm dtype", "Interpolate, AvgPool"] values = [[str(self.conv_dtype), str(self.FA_dtype), str(self.norm_dtype), str(self.replaced_type)]] df = pd.DataFrame(values, columns=headers) + print("VAE operators data types:") print(df) - def set_npu_env(self, args): + def set_npu_env(self, args, strategy_ckpt_save_file=""): rank_id, device_num = init_env( mode=args.mode, - device_target=args.device, + seed=getattr(args, "seed", 42), distributed=getattr(args, "use_parallel", False), - precision_mode=getattr(args, "precision_mode", None), + device_target=getattr(args, "device", "Ascend"), + max_device_memory=getattr(args, "max_device_memory", None), + parallel_mode=getattr(args, "parallel_mode", "data"), + mempool_block_size=getattr(args, "mempool_block_size", "9GB"), + global_bf16=getattr(args, "global_bf16", False), + strategy_ckpt_save_file=strategy_ckpt_save_file, + optimizer_weight_shard_size=getattr(args, "optimizer_weight_shard_size", 8), + sp_size=getattr(args, "sp_size", 1), jit_level=getattr(args, "jit_level", None), + enable_parallel_fusion=getattr(args, "enable_parallel_fusion", False), jit_syntax_level=getattr(args, "jit_syntax_level", "strict"), + comm_fusion=getattr(args, "comm_fusion", False), ) self.rank = rank_id self.bind_thread_to_cpu() diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index efce155a0d..c886c99972 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -23,12 +23,12 @@ from opensora.models.diffusion.common import PatchEmbed2D from opensora.models.diffusion.opensora.modules import Attention, LayerNorm from opensora.models.diffusion.opensora.net_with_loss import DiffusionWithLoss, DiffusionWithLossEval +from opensora.npu_config import npu_config from opensora.train.commons import create_loss_scaler, parse_args from opensora.utils.callbacks import EMAEvalSwapCallback, PerfRecorderCallback from opensora.utils.dataset_utils import Collate, LengthGroupedBatchSampler from opensora.utils.ema import EMA from opensora.utils.message_utils import print_banner -from opensora.utils.ms_utils import init_env from opensora.utils.utils import get_precision from mindone.diffusers.models.activations import SiLU @@ -74,23 +74,9 @@ def set_all_reduce_fusion( def main(args): # 1. init save_src_strategy = args.use_parallel and args.parallel_mode != "data" - rank_id, device_num = init_env( - args.mode, - seed=args.seed, - distributed=args.use_parallel, - device_target=args.device, - max_device_memory=args.max_device_memory, - parallel_mode=args.parallel_mode, - mempool_block_size=args.mempool_block_size, - global_bf16=args.global_bf16, - strategy_ckpt_save_file=os.path.join(args.output_dir, "src_strategy.ckpt") if save_src_strategy else "", - optimizer_weight_shard_size=args.optimizer_weight_shard_size, - sp_size=args.sp_size if args.num_frames != 1 and args.use_image_num == 0 else 1, - jit_level=args.jit_level, - enable_parallel_fusion=args.enable_parallel_fusion, - jit_syntax_level=args.jit_syntax_level, - comm_fusion=args.comm_fusion, - ) + if args.num_frames == 1 or args.use_image_num != 0: + args.sp_size = 1 + rank_id, device_num = npu_config.set_npu_env(args, strategy_ckpt_save_file=save_src_strategy) set_logger(name="", output_dir=args.output_dir, rank=rank_id, log_level=eval(args.log_level)) # 2. Init and load models From 5b062265cb8a473e0fae1f28e6e43fd7dc926c0a Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 22 Nov 2024 17:18:27 +0800 Subject: [PATCH 061/133] replace ops.cat with mint.cat --- .../opensora/models/diffusion/common.py | 30 +++++++++++-------- .../diffusion/opensora/net_with_loss.py | 8 ++--- .../models/diffusion/utils/pos_embed.py | 12 ++++---- .../models/text_encoder/t5_encoder.py | 4 +-- .../opensora/sample/pipeline_opensora.py | 16 +++++----- .../opensora/train/train_t2v_diffusers.py | 1 + 6 files changed, 38 insertions(+), 33 deletions(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/common.py b/examples/opensora_pku/opensora/models/diffusion/common.py index 1e00e98fa8..799c49b8ea 100644 --- a/examples/opensora_pku/opensora/models/diffusion/common.py +++ b/examples/opensora_pku/opensora/models/diffusion/common.py @@ -13,26 +13,30 @@ class PatchEmbed2D(nn.Cell): def __init__( self, - patch_size=16, #2 - in_channels=3, #8 - embed_dim=768, # 24*96=2304 + patch_size=16, # 2 + in_channels=3, # 8 + embed_dim=768, # 24*96=2304 bias=True, ): super().__init__() self.proj = nn.Conv2d( - in_channels, embed_dim, - kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), has_bias=bias, pad_mode="pad" + in_channels, + embed_dim, + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + has_bias=bias, + pad_mode="pad", ) def construct(self, latent): - b, c, t, h, w = latent.shape # b, c=in_channels, t, h, w + b, c, t, h, w = latent.shape # b, c=in_channels, t, h, w # b c t h w -> (b t) c h w - latent = latent.swapaxes(1, 2).reshape(b*t, c, h, w) # b*t, c, h, w + latent = latent.swapaxes(1, 2).reshape(b * t, c, h, w) # b*t, c, h, w latent = self.proj(latent) # b*t, embed_dim, h, w # (b t) c h w -> b (t h w) c _, c, h, w = latent.shape - latent = latent.reshape(b, -1, c, h, w).permute(0, 1, 3, 4, 2).reshape(b, -1, c) # b, t*h*w, embed_dim - + latent = latent.reshape(b, -1, c, h, w).permute(0, 1, 3, 4, 2).reshape(b, -1, c) # b, t*h*w, embed_dim + return latent @@ -81,7 +85,7 @@ def __init__(self, freq=10000.0, F0=1.0, interpolation_scale_thw=(1, 1, 1), dim_ def get_cos_sin(self, seq_len, interpolation_scale=1): t = ops.arange(seq_len, dtype=self.inv_freq.dtype) / interpolation_scale freqs = ops.outer(t, self.inv_freq).to(self.inv_freq.dtype) - freqs = ops.cat((freqs, freqs), axis=-1) + freqs = mint.cat((freqs, freqs), dim=-1) cos = freqs.cos() # (Seq, Dim) sin = freqs.sin() return cos, sin @@ -89,7 +93,7 @@ def get_cos_sin(self, seq_len, interpolation_scale=1): @staticmethod def rotate_half(x): x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] - return ops.cat((-x2, x1), axis=-1) + return mint.cat((-x2, x1), dim=-1) def apply_rope1d(self, tokens, pos1d, cos, sin): assert pos1d.ndim == 2 @@ -124,5 +128,5 @@ def construct(self, tokens, positions): t = self.apply_rope1d(t, poses[0], cos_t.to(tokens.dtype), sin_t.to(tokens.dtype)) y = self.apply_rope1d(y, poses[1], cos_y.to(tokens.dtype), sin_y.to(tokens.dtype)) x = self.apply_rope1d(x, poses[2], cos_x.to(tokens.dtype), sin_x.to(tokens.dtype)) - tokens = ops.cat((t, y, x), axis=-1) - return tokens \ No newline at end of file + tokens = mint.cat((t, y, x), dim=-1) + return tokens diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py b/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py index f3fd194fd5..5cca085e13 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py @@ -4,7 +4,7 @@ from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info import mindspore as ms -from mindspore import nn, ops +from mindspore import mint, nn, ops from mindone.diffusers.training_utils import compute_snr @@ -31,7 +31,7 @@ def __init__( noise_scheduler, vae: nn.Cell = None, text_encoder: nn.Cell = None, - text_encoder_2: nn.Cell = None, # not to use yet + text_encoder_2: nn.Cell = None, # not to use yet text_emb_cached: bool = True, video_emb_cached: bool = False, use_image_num: int = 0, @@ -115,7 +115,7 @@ def get_latents(self, x): # (b*f, c, 1, h, w) -> (b*f, c, h, w) -> (b, f, c, h, w) -> (b, c, f, h, w) _, c, _, h, w = images.shape images = images.squeeze(2).reshape(B, self.use_image_num, c, h, w).permute(0, 2, 1, 3, 4) - z = ops.cat([videos, images], axis=2) # b c 16+4, h, w + z = mint.cat([videos, images], dim=2) # b c 16+4, h, w else: raise ValueError("Incorrect Dimensions of x") return z @@ -126,7 +126,7 @@ def construct( attention_mask: ms.Tensor, text_tokens: ms.Tensor, encoder_attention_mask: ms.Tensor = None, - ): # TODO: in the future add 2nd text encoder and tokens + ): # TODO: in the future add 2nd text encoder and tokens """ Video diffusion model forward and loss computation for training diff --git a/examples/opensora_pku/opensora/models/diffusion/utils/pos_embed.py b/examples/opensora_pku/opensora/models/diffusion/utils/pos_embed.py index 06d52b0fa8..faea683ac9 100644 --- a/examples/opensora_pku/opensora/models/diffusion/utils/pos_embed.py +++ b/examples/opensora_pku/opensora/models/diffusion/utils/pos_embed.py @@ -1,5 +1,5 @@ import mindspore as ms -from mindspore import nn, ops +from mindspore import mint, nn, ops # ---------------------------------------------------------- # RoPE2D: RoPE implementation in 2D @@ -26,7 +26,7 @@ def get_cos_sin(self, D, seq_len, dtype): t = ops.arange(0, seq_len, dtype=inv_freq.dtype) # freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) freqs = ops.outer(t, inv_freq).to(dtype) - freqs = ops.cat((freqs, freqs), axis=-1) + freqs = mint.cat((freqs, freqs), dim=-1) cos = freqs.cos() # (Seq, Dim) sin = freqs.sin() self.cache[D, seq_len, dtype] = (cos, sin) @@ -35,7 +35,7 @@ def get_cos_sin(self, D, seq_len, dtype): @staticmethod def rotate_half(x): x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] - return ops.cat((-x2, x1), axis=-1) + return mint.cat((-x2, x1), dim=-1) def apply_rope1d(self, tokens, pos1d, cos, sin): assert pos1d.ndim == 2 @@ -60,7 +60,7 @@ def construct(self, tokens, positions): y, x = tokens.chunk(2, axis=-1) y = self.apply_rope1d(y, positions[:, :, 0], cos, sin) x = self.apply_rope1d(x, positions[:, :, 1], cos, sin) - tokens = ops.cat((y, x), axis=-1) + tokens = mint.cat((y, x), dim=-1) return tokens @@ -97,7 +97,7 @@ def get_cos_sin(self, D, seq_len, dtype): t = ops.arange(0, seq_len, dtype=inv_freq.dtype) # freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) freqs = ops.outer(t, inv_freq).to(dtype) - freqs = ops.cat((freqs, freqs), axis=-1) + freqs = mint.cat((freqs, freqs), dim=-1) cos = freqs.cos() # (Seq, Dim) sin = freqs.sin() self.cache[D, seq_len, dtype] = (cos, sin) @@ -106,7 +106,7 @@ def get_cos_sin(self, D, seq_len, dtype): @staticmethod def rotate_half(x): x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] - return ops.cat((-x2, x1), axis=-1) + return mint.cat((-x2, x1), dim=-1) def apply_rope1d(self, tokens, pos1d, cos, sin): assert pos1d.ndim == 2 diff --git a/examples/opensora_pku/opensora/models/text_encoder/t5_encoder.py b/examples/opensora_pku/opensora/models/text_encoder/t5_encoder.py index 6541428fc8..33fbb2126f 100644 --- a/examples/opensora_pku/opensora/models/text_encoder/t5_encoder.py +++ b/examples/opensora_pku/opensora/models/text_encoder/t5_encoder.py @@ -1,12 +1,12 @@ import copy import logging +import mint import numpy as np from transformers.models.t5.configuration_t5 import T5Config import mindspore as ms import mindspore.nn as nn -import mindspore.ops as ops from mindone.transformers.activations import ACT2FN from mindone.transformers.modeling_utils import MSPreTrainedModel as PreTrainedModel @@ -235,7 +235,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): if key_value_states is None: # self-attn # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = ops.cat([past_key_value, hidden_states], axis=2) + hidden_states = mint.cat([past_key_value, hidden_states], dim=2) elif past_key_value.shape[2] != key_value_states.shape[1]: # checking that the `sequence_length` of the `past_key_value` is the same as # the provided `key_value_states` to support prefix tuning diff --git a/examples/opensora_pku/opensora/sample/pipeline_opensora.py b/examples/opensora_pku/opensora/sample/pipeline_opensora.py index 372c9f027c..1fab66343b 100644 --- a/examples/opensora_pku/opensora/sample/pipeline_opensora.py +++ b/examples/opensora_pku/opensora/sample/pipeline_opensora.py @@ -638,11 +638,11 @@ def __call__( # 7 create image_rotary_emb, style embedding & time ids if self.do_classifier_free_guidance: - prompt_embeds = ops.cat([negative_prompt_embeds, prompt_embeds], axis=0) - prompt_attention_mask = ops.cat([negative_prompt_attention_mask, prompt_attention_mask], axis=0) + prompt_embeds = mint.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = mint.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) if self.tokenizer_2 is not None: - prompt_embeds_2 = ops.cat([negative_prompt_embeds_2, prompt_embeds_2], axis=0) - prompt_attention_mask_2 = ops.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2], axis=0) + prompt_embeds_2 = mint.cat([negative_prompt_embeds_2, prompt_embeds_2], dim=0) + prompt_attention_mask_2 = mint.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2], dim=0) # ==================make sp===================================== if get_sequence_parallel_state(): @@ -656,7 +656,7 @@ def __call__( latents, temp_attention_mask = self.prepare_parallel_latent(latents) temp_attention_mask = ( - ops.cat([temp_attention_mask] * 2) + mint.cat([temp_attention_mask] * 2) if (self.do_classifier_free_guidance and temp_attention_mask is not None) else temp_attention_mask ) @@ -677,7 +677,7 @@ def __call__( continue # expand the latents if we are doing classifier free guidance - latent_model_input = ops.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = mint.cat([latents] * 2) if self.do_classifier_free_guidance else latents if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -766,7 +766,7 @@ def __call__( # all_latents = ops.zeros(full_shape, dtype=latents.dtype) all_latents = self.all_gather(latents) latents_list = mint.chunk(all_latents, world_size, axis=0) - latents = ops.cat(latents_list, axis=2) + latents = mint.cat(latents_list, dim=2) # ==================make sp===================================== if not output_type == "latents": @@ -815,5 +815,5 @@ def decode_latents(self, latents): bs = latents.shape[0] for i in range(bs): out.append(per_sample_func(latents[i : i + 1])) - out = ops.cat(out, axis=0) + out = mint.cat(out, dim=0) return out # b t h w c diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index c886c99972..435572f596 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -77,6 +77,7 @@ def main(args): if args.num_frames == 1 or args.use_image_num != 0: args.sp_size = 1 rank_id, device_num = npu_config.set_npu_env(args, strategy_ckpt_save_file=save_src_strategy) + npu_config.print_ops_dtype_info() set_logger(name="", output_dir=args.output_dir, rank=rank_id, log_level=eval(args.log_level)) # 2. Init and load models From 0623ebc4fb95d60f56cf6b193f9292cca6b2260a Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:19:53 +0800 Subject: [PATCH 062/133] remove conflicting requirements --- examples/opensora_pku/requirements.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/opensora_pku/requirements.txt b/examples/opensora_pku/requirements.txt index 6860edc09c..f501cacbc2 100644 --- a/examples/opensora_pku/requirements.txt +++ b/examples/opensora_pku/requirements.txt @@ -14,7 +14,9 @@ omegaconf pyyaml sentencepiece mindnlp==0.4.0 -transformers>=4.46.0 bs4 huggingface_hub>=0.22.2 decord +pillow +tokenizers +transformers From 1ded897b6693a31fc150a5b311ca787e0187ecf7 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 26 Nov 2024 10:51:05 +0800 Subject: [PATCH 063/133] connect_res_layer_num = 2 --- examples/opensora_pku/scripts/causalvae/wfvae_8dim.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/opensora_pku/scripts/causalvae/wfvae_8dim.json b/examples/opensora_pku/scripts/causalvae/wfvae_8dim.json index 995d10e868..1f55754145 100644 --- a/examples/opensora_pku/scripts/causalvae/wfvae_8dim.json +++ b/examples/opensora_pku/scripts/causalvae/wfvae_8dim.json @@ -1,8 +1,8 @@ { "_class_name": "WFVAEModel", - "_diffusers_version": "0.30.2", + "_diffusers_version": "0.28.0", "base_channels": 128, - "connect_res_layer_num": 1, + "connect_res_layer_num": 2, "decoder_energy_flow_hidden_size": 128, "decoder_num_resblocks": 2, "dropout": 0.0, From 3b24f409862ba23bde6b42c27b3903dd8fc963cd Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 26 Nov 2024 17:24:59 +0800 Subject: [PATCH 064/133] add new single-device training script --- .../single-device/sample_t2i_1x480p.sh | 19 ----- .../single-device/train_debug.sh | 72 ------------------- .../single-device/train_image3d_480p.sh | 43 ----------- .../single-device/train_t2v_stage1.sh | 58 +++++++++++++++ .../single-device/train_t2v_stage2.sh | 58 +++++++++++++++ .../single-device/train_t2v_stage3.sh | 58 +++++++++++++++ .../single-device/train_video3d_nx480p.sh | 49 ------------- .../single-device/train_video3d_nx720p.sh | 49 ------------- 8 files changed, 174 insertions(+), 232 deletions(-) delete mode 100644 examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x480p.sh delete mode 100644 examples/opensora_pku/scripts/text_condition/single-device/train_debug.sh delete mode 100644 examples/opensora_pku/scripts/text_condition/single-device/train_image3d_480p.sh create mode 100644 examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage1.sh create mode 100644 examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh create mode 100644 examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh delete mode 100644 examples/opensora_pku/scripts/text_condition/single-device/train_video3d_nx480p.sh delete mode 100644 examples/opensora_pku/scripts/text_condition/single-device/train_video3d_nx720p.sh diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x480p.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x480p.sh deleted file mode 100644 index b4db96741d..0000000000 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x480p.sh +++ /dev/null @@ -1,19 +0,0 @@ -export DEVICE_ID=0 -python opensora/sample/sample_t2v.py \ - --model_path LanguageBind/Open-Sora-Plan-v1.2.0/1x480p \ - --num_frames 1 \ - --height 480 \ - --width 640 \ - --cache_dir "./" \ - --text_encoder_name google/mt5-xxl \ - --text_prompt examples/prompt_list_0.txt \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path LanguageBind/Open-Sora-Plan-v1.2.0/vae\ - --save_img_path "./sample_videos/prompt_list_0_1x480p" \ - --fps 24 \ - --guidance_scale 4.5 \ - --num_sampling_steps 100 \ - --enable_tiling \ - --max_sequence_length 512 \ - --sample_method EulerAncestralDiscrete \ - --model_type "dit" \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_debug.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_debug.sh deleted file mode 100644 index 6fb00f263f..0000000000 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_debug.sh +++ /dev/null @@ -1,72 +0,0 @@ -''' -Training scheduler -We replaced the eps-pred loss with v-pred loss and enable ZeroSNR. For videos, we resample to 16 FPS for training. - -Stage 1: We initially initialized from the image weights of version 1.2.0 and trained images at a resolution of 1x320x320. The objective of this phase was to fine-tune the 3D dense attention model to a sparse attention model. The entire fine-tuning process involved approximately 100k steps, with a batch size of 1024 and a learning rate of 2e-5. The image data was primarily sourced from SAM in version 1.2.0. - -Stage 2: We trained the model jointly on images and videos, with a maximum resolution of 93x320x320. -The entire fine-tuning process involved approximately 300k steps, with a batch size of 1024 and a learning rate of 2e-5. -The image data was primarily sourced from SAM in version 1.2.0, while the video data consisted of the unfiltered Panda70m. -In fact, the model had nearly converged around 100k steps, and by 300k steps, there were no significant gains. -Subsequently, we performed data cleaning and caption rewriting, with further data analysis discussed at the end. - -Stage 3: We fine-tuned the model using our filtered Panda70m dataset, with a fixed resolution of 93x352x640. The entire fine-tuning process involved approximately 30k steps, with a batch size of 1024 and a learning rate of 1e-5. -''' - -# Stage 2: 93x320x320 -export DEVICE_ID=0 -NUM_FRAME=29 -HEIGHT=320 -WIDTH=320 -python opensora/train/train_t2v_diffusers.py \ - --model OpenSoraT2V_v1_3-2B/122 \ - --text_encoder_name_1 /home_host/susan/workspace/checkpoints/google/mt5-xxl \ - --cache_dir "./" \ - --dataset t2v \ - --data "scripts/train_data/video_data_v1_2.txt" \ - --ae WFVAEModel_D8_4x8x8 \ - --ae_path /home_host/susan/workspace/checkpoints/LanguageBind/Open-Sora-Plan-v1.3.0/vae \ - --sample_rate 1 \ - --num_frames ${NUM_FRAME} \ - --max_height ${HEIGHT} \ - --max_width ${WIDTH} \ - --force_resolution \ - --interpolation_scale_t 1.0 \ - --interpolation_scale_h 1.0 \ - --interpolation_scale_w 1.0 \ - --gradient_checkpointing \ - --train_batch_size=1 \ - --dataloader_num_workers 1 \ - --gradient_accumulation_steps=1 \ - --max_train_steps 1000000 \ - --start_learning_rate=2e-5 \ - --lr_scheduler="constant" \ - --seed=10 \ - --lr_warmup_steps=500 \ - --precision="bf16" \ - --checkpointing_steps=1000 \ - --output_dir="./checkpoints/t2v-${NUM_FRAME}x${HEIGHT}x${WIDTH}/" \ - --model_max_length 512 \ - --use_image_num 0 \ - --cfg 0.1 \ - --snr_gamma 5.0 \ - --use_ema True\ - --ema_start_step 0 \ - --enable_tiling \ - --tile_overlap_factor 0.125 \ - --clip_grad True \ - --max_grad_norm 1.0 \ - --use_rope \ - --noise_offset 0.02 \ - --enable_stable_fp32 True \ - --ema_decay 0.999 \ - --speed_factor 1.0 \ - --drop_short_ratio 1.0 \ - --hw_stride 32 \ - --sparse1d \ - --sparse_n 4 \ - --train_fps 16 \ - --trained_data_global_step 0 \ - --group_data \ - --prediction_type "v_prediction" \ - --mode 1 diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_image3d_480p.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_image3d_480p.sh deleted file mode 100644 index 2ab1fc7cfb..0000000000 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_image3d_480p.sh +++ /dev/null @@ -1,43 +0,0 @@ -# Stage 2: 1x480p, maybe oom on 910* -export DEVICE_ID=0 -python opensora/train/train_t2v_diffusers.py \ - --model OpenSoraT2V-ROPE-L/122 \ - --text_encoder_name google/mt5-xxl \ - --cache_dir "./" \ - --dataset t2v \ - --data "scripts/train_data/merge_data.txt" \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path "LanguageBind/Open-Sora-Plan-v1.2.0/vae" \ - --sample_rate 1 \ - --num_frames 1 \ - --max_height 480 \ - --max_width 640 \ - --interpolation_scale_t 1.0 \ - --interpolation_scale_h 1.0 \ - --interpolation_scale_w 1.0 \ - --attention_mode xformers \ - --gradient_checkpointing \ - --train_batch_size=8 \ - --dataloader_num_workers 20 \ - --gradient_accumulation_steps=1 \ - --max_train_steps=1000000 \ - --start_learning_rate=1e-4 \ - --lr_scheduler="constant" \ - --seed=10 \ - --lr_warmup_steps=500 \ - --precision="bf16" \ - --checkpointing_steps=2000 \ - --output_dir="t2i-image3d-1x480p/" \ - --model_max_length 512 \ - --use_image_num 0 \ - --snr_gamma 5.0 \ - --use_ema True\ - --ema_start_step 0 \ - --ema_decay 0.999 \ - --enable_tiling \ - --tile_overlap_factor 0.0 \ - --pretrained "path/to/pretrained/1x240p/ckpt" \ - --clip_grad True \ - --max_grad_norm 1.0 \ - --use_rope \ - --noise_offset 0.02 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage1.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage1.sh new file mode 100644 index 0000000000..fe6f01b9d9 --- /dev/null +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage1.sh @@ -0,0 +1,58 @@ +# Stage 1: 1x320x320 +NUM_FRAME=1 +WIDTH=320 +HEIGHT=320 +python opensora/train/train_t2v_diffusers.py \ + --model OpenSoraT2V_v1_3-2B/122 \ + --text_encoder_name_1 google/mt5-xxl \ + --cache_dir "./" \ + --dataset t2v \ + --data "scripts/train_data/image_data_v1_2.txt" \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ + --sample_rate 1 \ + --num_frames ${NUM_FRAME} \ + --max_height ${HEIGHT} \ + --max_width ${WIDTH} \ + --interpolation_scale_t 1.0 \ + --interpolation_scale_h 1.0 \ + --interpolation_scale_w 1.0 \ + --gradient_checkpointing \ + --train_batch_size=1 \ + --dataloader_num_workers 8 \ + --gradient_accumulation_steps=1 \ + --max_train_steps=1000000 \ + --start_learning_rate=2e-5 \ + --lr_scheduler="constant" \ + --seed=10 \ + --lr_warmup_steps=0 \ + --precision="bf16" \ + --checkpointing_steps=1000 \ + --output_dir="t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}/" \ + --model_max_length 512 \ + --use_image_num 0 \ + --cfg 0.1 \ + --snr_gamma 5.0 \ + --rescale_betas_zero_snr \ + --use_ema True\ + --ema_start_step 0 \ + --enable_tiling \ + --tile_overlap_factor 0.125 \ + --clip_grad True \ + --max_grad_norm 1.0 \ + --noise_offset 0.02 \ + --enable_stable_fp32 True\ + --ema_decay 0.999 \ + --speed_factor 1.0 \ + --drop_short_ratio 0.0 \ + --max_device_memory "59GB" \ + --jit_syntax_level "lax" \ + --dataset_sink_mode False \ + --prediction_type "v_prediction" \ + --hw_stride 32 \ + --sparse1d \ + --sparse_n 4 \ + --train_fps 16 \ + --trained_data_global_step 0 \ + --group_data \ + --mode 1 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh new file mode 100644 index 0000000000..94e09d4246 --- /dev/null +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh @@ -0,0 +1,58 @@ +# Stage 2: 93x320x320 +NUM_FRAME=93 +WIDTH=320 +HEIGHT=320 +python opensora/train/train_t2v_diffusers.py \ + --model OpenSoraT2V_v1_3-2B/122 \ + --text_encoder_name_1 google/mt5-xxl \ + --cache_dir "./" \ + --dataset t2v \ + --data "scripts/train_data/video_data_v1_2.txt" \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ + --sample_rate 1 \ + --num_frames ${NUM_FRAME} \ + --max_height ${HEIGHT} \ + --max_width ${WIDTH} \ + --interpolation_scale_t 1.0 \ + --interpolation_scale_h 1.0 \ + --interpolation_scale_w 1.0 \ + --gradient_checkpointing \ + --train_batch_size=1 \ + --dataloader_num_workers 8 \ + --gradient_accumulation_steps=1 \ + --max_train_steps=1000000 \ + --start_learning_rate=2e-5 \ + --lr_scheduler="constant" \ + --seed=10 \ + --lr_warmup_steps=0 \ + --precision="bf16" \ + --checkpointing_steps=1000 \ + --output_dir="t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}/" \ + --model_max_length 512 \ + --use_image_num 0 \ + --cfg 0.1 \ + --snr_gamma 5.0 \ + --rescale_betas_zero_snr \ + --use_ema True\ + --ema_start_step 0 \ + --enable_tiling \ + --tile_overlap_factor 0.125 \ + --clip_grad True \ + --max_grad_norm 1.0 \ + --noise_offset 0.02 \ + --enable_stable_fp32 True\ + --ema_decay 0.999 \ + --speed_factor 1.0 \ + --drop_short_ratio 0.0 \ + --max_device_memory "59GB" \ + --jit_syntax_level "lax" \ + --dataset_sink_mode False \ + --prediction_type "v_prediction" \ + --hw_stride 32 \ + --sparse1d \ + --sparse_n 4 \ + --train_fps 16 \ + --trained_data_global_step 0 \ + --group_data \ + --mode 1 diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh new file mode 100644 index 0000000000..b9cd232a32 --- /dev/null +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh @@ -0,0 +1,58 @@ +# Stage 3: 93x480x480 (480x480, 640x352, 352x640) +NUM_FRAME=93 +WIDTH=480 +HEIGHT=480 +python opensora/train/train_t2v_diffusers.py \ + --model OpenSoraT2V_v1_3-2B/122 \ + --text_encoder_name_1 google/mt5-xxl \ + --cache_dir "./" \ + --dataset t2v \ + --data "scripts/train_data/video_data_v1_2.txt" \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ + --sample_rate 1 \ + --num_frames ${NUM_FRAME} \ + --max_height ${HEIGHT} \ + --max_width ${WIDTH} \ + --interpolation_scale_t 1.0 \ + --interpolation_scale_h 1.0 \ + --interpolation_scale_w 1.0 \ + --gradient_checkpointing \ + --train_batch_size=1 \ + --dataloader_num_workers 8 \ + --gradient_accumulation_steps=1 \ + --max_train_steps=1000000 \ + --start_learning_rate=1e-5 \ + --lr_scheduler="constant" \ + --seed=10 \ + --lr_warmup_steps=0 \ + --precision="bf16" \ + --checkpointing_steps=1000 \ + --output_dir="t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}/" \ + --model_max_length 512 \ + --use_image_num 0 \ + --cfg 0.1 \ + --snr_gamma 5.0 \ + --rescale_betas_zero_snr \ + --use_ema True\ + --ema_start_step 0 \ + --enable_tiling \ + --tile_overlap_factor 0.125 \ + --clip_grad True \ + --max_grad_norm 1.0 \ + --noise_offset 0.02 \ + --enable_stable_fp32 True\ + --ema_decay 0.999 \ + --speed_factor 1.0 \ + --drop_short_ratio 0.0 \ + --max_device_memory "59GB" \ + --jit_syntax_level "lax" \ + --dataset_sink_mode False \ + --prediction_type "v_prediction" \ + --hw_stride 32 \ + --sparse1d \ + --sparse_n 4 \ + --train_fps 16 \ + --trained_data_global_step 0 \ + --group_data \ + --mode 1 diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_video3d_nx480p.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_video3d_nx480p.sh deleted file mode 100644 index f56ac27d0b..0000000000 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_video3d_nx480p.sh +++ /dev/null @@ -1,49 +0,0 @@ -# Stage 3: 29x480p, maybe oom on 910* -export DEVICE_ID=0 -NUM_FRAME=29 -python opensora/train/train_t2v_diffusers.py \ - --model OpenSoraT2V-ROPE-L/122 \ - --text_encoder_name google/mt5-xxl \ - --cache_dir "./" \ - --dataset t2v \ - --data "scripts/train_data/merge_data_mixkit.txt" \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path "LanguageBind/Open-Sora-Plan-v1.2.0/vae" \ - --sample_rate 1 \ - --num_frames ${NUM_FRAME} \ - --max_height 480 \ - --max_width 640 \ - --interpolation_scale_t 1.0 \ - --interpolation_scale_h 1.0 \ - --interpolation_scale_w 1.0 \ - --attention_mode xformers \ - --gradient_checkpointing \ - --train_batch_size=1 \ - --dataloader_num_workers 8 \ - --gradient_accumulation_steps=1 \ - --max_train_steps=1000000 \ - --start_learning_rate=1e-4 \ - --lr_scheduler="constant" \ - --seed=10 \ - --lr_warmup_steps=500 \ - --precision="bf16" \ - --checkpointing_steps=1000 \ - --output_dir="t2v-video3d-${NUM_FRAME}x480p/" \ - --model_max_length 512 \ - --use_image_num 0 \ - --cfg 0.1 \ - --snr_gamma 5.0 \ - --use_ema True\ - --ema_start_step 0 \ - --enable_tiling \ - --tile_overlap_factor 0.125 \ - --clip_grad True \ - --max_grad_norm 1.0 \ - --use_rope \ - --noise_offset 0.02 \ - --enable_stable_fp32 True\ - --ema_decay 0.999 \ - --speed_factor 1.0 \ - --drop_short_ratio 1.0 \ - --pretrained "LanguageBind/Open-Sora-Plan-v1.2.0/1x480p" \ - # --group_frame \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_video3d_nx720p.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_video3d_nx720p.sh deleted file mode 100644 index 0da3eb287f..0000000000 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_video3d_nx720p.sh +++ /dev/null @@ -1,49 +0,0 @@ -# Stage 4: 29x720p, maybe oom on 910* -export DEVICE_ID=0 -NUM_FRAME=29 -python opensora/train/train_t2v_diffusers.py \ - --model OpenSoraT2V-ROPE-L/122 \ - --text_encoder_name google/mt5-xxl \ - --cache_dir "./" \ - --dataset t2v \ - --data "scripts/train_data/merge_data_mixkit.txt" \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path "LanguageBind/Open-Sora-Plan-v1.2.0/vae" \ - --sample_rate 1 \ - --num_frames ${NUM_FRAME} \ - --max_height 720 \ - --max_width 1280 \ - --interpolation_scale_t 1.0 \ - --interpolation_scale_h 1.5 \ - --interpolation_scale_w 2.0 \ - --attention_mode xformers \ - --gradient_checkpointing \ - --train_batch_size=1 \ - --dataloader_num_workers 8 \ - --gradient_accumulation_steps=1 \ - --max_train_steps=1000000 \ - --start_learning_rate=1e-4 \ - --lr_scheduler="constant" \ - --seed=10 \ - --lr_warmup_steps=500 \ - --precision="bf16" \ - --checkpointing_steps=1000 \ - --output_dir="t2v-video3d-${NUM_FRAME}x720p/" \ - --model_max_length 512 \ - --use_image_num 0 \ - --cfg 0.1 \ - --snr_gamma 5.0 \ - --use_ema True\ - --ema_start_step 0 \ - --enable_tiling \ - --tile_overlap_factor 0.125 \ - --clip_grad True \ - --max_grad_norm 1.0 \ - --use_rope \ - --noise_offset 0.02 \ - --enable_stable_fp32 True\ - --ema_decay 0.999 \ - --speed_factor 1.0 \ - --drop_short_ratio 1.0 \ - --pretrained "LanguageBind/Open-Sora-Plan-v1.2.0/29x480p" \ - # --group_frame \ From 219b66ec6eb65e2dcf71d29e40248f6c030a58be Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 27 Nov 2024 13:04:53 +0800 Subject: [PATCH 065/133] no_grad in train pipelines --- .../diffusion/opensora/net_with_loss.py | 51 +++++++++++++------ 1 file changed, 36 insertions(+), 15 deletions(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py b/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py index 5cca085e13..cc4665bff0 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py @@ -4,7 +4,7 @@ from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info import mindspore as ms -from mindspore import mint, nn, ops +from mindspore import _no_grad, mint, nn, ops from mindone.diffusers.training_utils import compute_snr @@ -13,6 +13,25 @@ logger = logging.getLogger(__name__) +@ms.jit_class +class no_grad(_no_grad): + """ + A context manager that suppresses gradient memory allocation in PyNative mode. + """ + + def __init__(self): + super().__init__() + self._pynative = ms.get_context("mode") == ms.PYNATIVE_MODE + + def __enter__(self): + if self._pynative: + super().__enter__() + + def __exit__(self, *args): + if self._pynative: + super().__exit__(*args) + + class DiffusionWithLoss(nn.Cell): """An training pipeline for diffusion model @@ -145,14 +164,15 @@ def construct( """ # 1. get image/video latents z using vae x = x.to(self.dtype) - if not self.video_emb_cached: - x = ops.stop_gradient(self.get_latents(x)) + with no_grad(): + if not self.video_emb_cached: + x = ops.stop_gradient(self.get_latents(x)) - # 2. get conditions - if not self.text_emb_cached: - text_embed = ops.stop_gradient(self.get_condition_embeddings(text_tokens, encoder_attention_mask)) - else: - text_embed = text_tokens + # 2. get conditions + if not self.text_emb_cached: + text_embed = ops.stop_gradient(self.get_condition_embeddings(text_tokens, encoder_attention_mask)) + else: + text_embed = text_tokens loss = self.compute_loss(x, attention_mask, text_embed, encoder_attention_mask) return loss @@ -276,14 +296,15 @@ def construct( """ # 1. get image/video latents z using vae x = x.to(self.dtype) - if not self.video_emb_cached: - x = ops.stop_gradient(self.get_latents(x)) + with no_grad(): + if not self.video_emb_cached: + x = ops.stop_gradient(self.get_latents(x)) - # 2. get conditions - if not self.text_emb_cached: - text_embed = ops.stop_gradient(self.get_condition_embeddings(text_tokens, encoder_attention_mask)) - else: - text_embed = text_tokens + # 2. get conditions + if not self.text_emb_cached: + text_embed = ops.stop_gradient(self.get_condition_embeddings(text_tokens, encoder_attention_mask)) + else: + text_embed = text_tokens loss, model_pred, target = self.compute_loss(x, attention_mask, text_embed, encoder_attention_mask) return loss, model_pred, target From 972c630389fabd64a2f59f377ad2a4c8a4b41fc5 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 27 Nov 2024 13:06:48 +0800 Subject: [PATCH 066/133] remove vae auto-mixed precision --- .../opensora/train/train_t2v_diffusers.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index 435572f596..130b21e0c9 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -18,7 +18,6 @@ from opensora.dataset import getdataset from opensora.dataset.loader import create_dataloader from opensora.models.causalvideovae import ae_channel_config, ae_stride_config, ae_wrapper -from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate from opensora.models.diffusion import Diffusion_models from opensora.models.diffusion.common import PatchEmbed2D from opensora.models.diffusion.opensora.modules import Attention, LayerNorm @@ -98,16 +97,6 @@ def main(args): "dtype": vae_dtype, } vae = ae_wrapper[args.ae](args.ae_path, **kwarg) - # vae.vae_scale_factor = ae_stride_config[args.ae] - - if vae_dtype == ms.float16: - custom_fp32_cells = [nn.GroupNorm] if args.vae_keep_gn_fp32 else [] - else: - custom_fp32_cells = [nn.AvgPool2d, TrilinearInterpolate] - logger.info( - f"Use amp level O2 for causal 3D VAE with dtype={vae_dtype}, custom_fp32_cells: {custom_fp32_cells}" - ) - vae = auto_mixed_precision(vae, amp_level="O2", dtype=vae_dtype, custom_fp32_cells=custom_fp32_cells) vae.set_train(False) for param in vae.get_parameters(): # freeze vae From 227f2ecd1944e016cfe0976420c4ab8114f3f322 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 27 Nov 2024 13:13:38 +0800 Subject: [PATCH 067/133] allow memory_profile --- examples/opensora_pku/opensora/train/commons.py | 3 ++- examples/opensora_pku/opensora/train/train_t2v_diffusers.py | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/opensora_pku/opensora/train/commons.py b/examples/opensora_pku/opensora/train/commons.py index ef18fc4a1a..b4e49493c2 100644 --- a/examples/opensora_pku/opensora/train/commons.py +++ b/examples/opensora_pku/opensora/train/commons.py @@ -204,7 +204,8 @@ def parse_train_args(parser): help="whether to compute the validation set loss during training", ) parser.add_argument("--val_interval", default=1, type=int, help="Validation frequency in epochs") - parser.add_argument("--profile", default=False, type=str2bool, help="Profile or not") + parser.add_argument("--profile", default=False, type=str2bool, help="Profile time analysis or not") + parser.add_argument("--profile_memory", default=False, type=str2bool, help="Profile memory analysis or not") parser.add_argument( "--log_level", type=str, diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index 130b21e0c9..82f107a318 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -76,6 +76,12 @@ def main(args): if args.num_frames == 1 or args.use_image_num != 0: args.sp_size = 1 rank_id, device_num = npu_config.set_npu_env(args, strategy_ckpt_save_file=save_src_strategy) + if args.mode == 1: + ms.context.set_context(pynative_synchronize=True) + if args.profile_memory: + profiler = ms.Profiler(output_path="./mem_info", profile_memory=True) + ms.context.set_context(memory_optimize_level="O0") + logger.info(f"Memory profiling: {profiler}") npu_config.print_ops_dtype_info() set_logger(name="", output_dir=args.output_dir, rank=rank_id, log_level=eval(args.log_level)) From abcab7a8239fb8970711516019581906f27a8e81 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 27 Nov 2024 13:17:08 +0800 Subject: [PATCH 068/133] update rec_video_folder --- .../scripts/causalvae/rec_video_folder.sh | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh b/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh index 7b88f5c63a..002260191d 100644 --- a/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh +++ b/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh @@ -1,13 +1,14 @@ python examples/rec_video_folder.py \ --batch_size 1 \ - --real_video_dir ../test_eval/eyes_test \ - --generated_video_dir ../test_eval/eyes_gen \ + --real_video_dir datasets/UCF-101/ \ + --data_file_path datasets/ucf101_test.csv \ + --generated_video_dir recons/ucf101_test/ \ --device Ascend \ --sample_fps 10 \ --sample_rate 1 \ - --num_frames 65 \ - --height 480 \ - --width 640 \ + --num_frames 25 \ + --height 256 \ + --width 256 \ --num_workers 8 \ --ae "WFVAEModel_D8_4x8x8" \ --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ From 73da943a831fa6b99fa9bd79a6b5540db2cf95e5 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 27 Nov 2024 13:28:30 +0800 Subject: [PATCH 069/133] dynamic inputs --- .../opensora/train/train_t2v_diffusers.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index 82f107a318..0dfec11790 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -581,6 +581,19 @@ def main(args): ema=ema, ) + # set dynamic inputs + _bs = ms.Symbol(unique=True) + video = ms.Tensor(shape=[_bs, 3, None, None, None], dtype=ms.float32) # (b, c, f, h, w) + attention_mask = ms.Tensor(shape=[_bs, None, None, None], dtype=ms.float32) # (b, f, h, w) + text_tokens = ( + ms.Tensor(shape=[_bs, None, args.model_max_length, None], dtype=ms.float32) + if args.text_embed_cache + else ms.Tensor(shape=[_bs, None, args.model_max_length], dtype=ms.float32) + ) + encoder_attention_mask = ms.Tensor(shape=[_bs, None, args.model_max_length], dtype=ms.uint8) + net_with_grads.set_inputs(video, attention_mask, text_tokens, encoder_attention_mask) + logger.info("Dynamic inputs are initialized for training!") + if not args.global_bf16: model = Model( net_with_grads, From 21ed520d2827dacacd714529abbbc960b4d639fa Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 27 Nov 2024 16:24:01 +0800 Subject: [PATCH 070/133] update profile_memory --- .../opensora_pku/opensora/train/train_t2v_diffusers.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index 0dfec11790..33949233c5 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -76,11 +76,12 @@ def main(args): if args.num_frames == 1 or args.use_image_num != 0: args.sp_size = 1 rank_id, device_num = npu_config.set_npu_env(args, strategy_ckpt_save_file=save_src_strategy) - if args.mode == 1: - ms.context.set_context(pynative_synchronize=True) if args.profile_memory: + if args.mode == 1: + # maybe slow + ms.context.set_context(pynative_synchronize=True) profiler = ms.Profiler(output_path="./mem_info", profile_memory=True) - ms.context.set_context(memory_optimize_level="O0") + # ms.context.set_context(memory_optimize_level="O0") # enabling it may consume more memory logger.info(f"Memory profiling: {profiler}") npu_config.print_ops_dtype_info() set_logger(name="", output_dir=args.output_dir, rank=rank_id, log_level=eval(args.log_level)) From c14659cab3ed1e8b5c98cd3a0500210488106381 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 27 Nov 2024 16:26:47 +0800 Subject: [PATCH 071/133] no jit for dit --- .../opensora/models/diffusion/opensora/modeling_opensora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index c1481eafb7..735fa725dd 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -340,7 +340,7 @@ def get_attention_mask(self, attention_mask): attention_mask = attention_mask.to(ms.bool_) # use bool for sdpa return attention_mask - @ms.jit # use graph mode + # @ms.jit # use graph mode def construct( self, hidden_states: ms.Tensor, From cefb859638d72378d5d4ebb135e7d37dd3fd55ee Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 28 Nov 2024 09:27:30 +0800 Subject: [PATCH 072/133] sequence length can be not divisible by 16 --- examples/opensora_pku/opensora/npu_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/opensora_pku/opensora/npu_config.py b/examples/opensora_pku/opensora/npu_config.py index 3aaafa25aa..0f540c9308 100644 --- a/examples/opensora_pku/opensora/npu_config.py +++ b/examples/opensora_pku/opensora/npu_config.py @@ -232,7 +232,7 @@ def ms_flash_attention( if input_layout not in ["BSH", "BNSD"]: raise ValueError(f"input_layout must be in ['BSH', 'BNSD'], but get {input_layout}.") Bs, query_tokens, inner_dim = query.shape - assert query_tokens % 16 == 0, f"Sequence length of query must be divisible by 16, but got {query_tokens}." + # assert query_tokens % 16 == 0, f"Sequence length of query must be divisible by 16, but got {query_tokens}." key_tokens = key.shape[1] heads = head_num query = query.view(Bs, query_tokens, heads, -1) From d629441453ec1c5f3fc8ac297c72c561e2e2543a Mon Sep 17 00:00:00 2001 From: chenyingshu Date: Fri, 6 Dec 2024 18:05:52 +0800 Subject: [PATCH 073/133] ops.AlltoAll not support bf16 --- .../opensora_pku/opensora/acceleration/communications.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/opensora_pku/opensora/acceleration/communications.py b/examples/opensora_pku/opensora/acceleration/communications.py index 85efe6b3c6..c8026befc0 100644 --- a/examples/opensora_pku/opensora/acceleration/communications.py +++ b/examples/opensora_pku/opensora/acceleration/communications.py @@ -2,6 +2,7 @@ from opensora.acceleration.parallel_states import hccl_info +import mindspore as ms from mindspore import Tensor, mint, nn, ops logger = logging.getLogger(__name__) @@ -18,6 +19,10 @@ def __init__(self, scatter_dim: int, gather_dim: int): # self.alltoall = AlltoAll(split_count=self.sp_size, group=self.spg) def construct(self, input_: Tensor): + origin_dtype = input_.dtype + if input_.dtype == ms.bfloat16: + input_ = input_.to(ms.float32) + scatter_dim, gather_dim, sp_size = self.scatter_dim, self.gather_dim, self.sp_size inp_shape = list(input_.shape) inp_shape[scatter_dim] = inp_shape[scatter_dim] // sp_size @@ -44,7 +49,7 @@ def construct(self, input_: Tensor): + inp_shape[gather_dim + 1 :] ) - return output + return output.to(origin_dtype) class AllGather(nn.Cell): From 634f6c01ffb9d3857bc450a349e50ae53686e5cd Mon Sep 17 00:00:00 2001 From: chenyingshu Date: Mon, 16 Dec 2024 10:37:08 +0800 Subject: [PATCH 074/133] dyn resize --- .../opensora_pku/opensora/dataset/__init__.py | 34 +++++++++---- .../opensora/dataset/transform.py | 48 +++++++++++++++++++ .../diffusion/opensora/modeling_opensora.py | 1 - 3 files changed, 73 insertions(+), 10 deletions(-) diff --git a/examples/opensora_pku/opensora/dataset/__init__.py b/examples/opensora_pku/opensora/dataset/__init__.py index 1289ac8b90..d596a1d9a9 100644 --- a/examples/opensora_pku/opensora/dataset/__init__.py +++ b/examples/opensora_pku/opensora/dataset/__init__.py @@ -6,7 +6,7 @@ from transformers import AutoTokenizer from .t2v_datasets import T2V_dataset -from .transform import TemporalRandomCrop, center_crop_th_tw +from .transform import TemporalRandomCrop, center_crop_th_tw, spatial_stride_crop_video, maxhxw_resize def getdataset(args, dataset_file): @@ -26,14 +26,30 @@ def norm_func_albumentation(image, **kwargs): ), Resize(args.max_height, args.max_width, interpolation=mapping["bilinear"]), ] - resize = [ - Lambda( - name="crop_centercrop", - image=partial(center_crop_th_tw, th=args.max_height, tw=args.max_width, top_crop=False), - p=1.0, - ), - Resize(args.max_height, args.max_width, interpolation=mapping["bilinear"]), - ] + if args.force_resolution: + assert (args.max_height is not None) and (args.max_width is not None), "set max_height and max_width for fixed resolution" + resize = [ + Lambda( + name="crop_centercrop", + image=partial(center_crop_th_tw, th=args.max_height, tw=args.max_width, top_crop=False), + p=1.0, + ), + Resize(args.max_height, args.max_width, interpolation=mapping["bilinear"]), + ] + else: # dynamic resolution + assert args.max_hxw is not None, "set max_hxw for dynamic resolution" + resize = [ + Lambda( + name="maxhxw_resize", + image=partial(maxhxw_resize, max_hxw=args.max_hxw, interpolation_mode=mapping["bilinear"]), + p=1.0, + ), + Lambda( + name="spatial_stride_crop", + image=partial(spatial_stride_crop_video, stride=args.hw_stride), # default stride=32 + p=1.0, + ), + ] transform = Compose( [*resize, ToFloat(255.0), Lambda(name="ae_norm", image=norm_func_albumentation, p=1.0)], diff --git a/examples/opensora_pku/opensora/dataset/transform.py b/examples/opensora_pku/opensora/dataset/transform.py index 1bd3e6cd81..b1edd43fa0 100644 --- a/examples/opensora_pku/opensora/dataset/transform.py +++ b/examples/opensora_pku/opensora/dataset/transform.py @@ -6,6 +6,7 @@ import albumentations import ftfy from bs4 import BeautifulSoup +import numpy as np __all__ = ["create_video_transforms", "t5_text_preprocessing"] @@ -90,6 +91,53 @@ def center_crop_th_tw(image, th, tw, top_crop, **kwargs): cropped_image = crop(image, i, j, new_h, new_w) return cropped_image +def resize(image, h, w, interpolation_mode): + + resize_func = albumentations.Resize(h, w, interpolation = interpolation_mode) + + return resize_func(image=image)["image"] + +def get_params(h, w, stride): + th, tw = h // stride * stride, w // stride * stride + + i = (h - th) // 2 + j = (w - tw) // 2 + + return i, j, th, tw + +def spatial_stride_crop_video(image, stride, **kwargs): + """ + Args: + image (numpy array): Video clip to be cropped. Size is (H, W, C) + Returns: + numpy array: cropped video clip by stride. + size is (OH, OW, C) + """ + h, w = image.shape[:2] + i, j, h, w = get_params(h, w, stride) + return crop(image, i, j, h, w) + +def maxhxw_resize(image, max_hxw, interpolation_mode, **kwargs): + """ + First use the h*w, + then resize to the specified size + Args: + image (numpy array): Video clip to be cropped. Size is (H, W, C) + Returns: + numpy array: scale resized video clip. + """ + h, w = image.shape[:2] + if h * w > max_hxw: + scale_factor = np.sqrt(max_hxw / (h * w)) + tr_h = int(h * scale_factor) + tr_w = int(w * scale_factor) + else: + tr_h = h + tr_w = w + if h == tr_h and w == tr_w: + return image + resize_image = resize(image, tr_h, tr_w, interpolation_mode) + return resize_image # create text transform(preprocess) bad_punct_regex = re.compile( diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index 735fa725dd..31edf2913a 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -340,7 +340,6 @@ def get_attention_mask(self, attention_mask): attention_mask = attention_mask.to(ms.bool_) # use bool for sdpa return attention_mask - # @ms.jit # use graph mode def construct( self, hidden_states: ms.Tensor, From fa6f040913933cfa70ae61b822b52fdcbbcfd9c2 Mon Sep 17 00:00:00 2001 From: chenyingshu Date: Mon, 16 Dec 2024 10:39:40 +0800 Subject: [PATCH 075/133] update scripts --- .../text_condition/multi-devices/train_t2v_stage1.sh | 1 + .../text_condition/multi-devices/train_t2v_stage2.sh | 10 +++++----- .../text_condition/multi-devices/train_t2v_stage3.sh | 5 +++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage1.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage1.sh index e46ce5f429..ce393f2490 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage1.sh @@ -14,6 +14,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ --sample_rate 1 \ --num_frames ${NUM_FRAME} \ + --force_resolution \ --max_height ${HEIGHT} \ --max_width ${WIDTH} \ --interpolation_scale_t 1.0 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh index 7b3bc691bf..21452bbb59 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh @@ -1,7 +1,7 @@ -# Stage 2: 93x320x320 +# Stage 2: 93x640x640 NUM_FRAME=93 -WIDTH=320 -HEIGHT=320 +WIDTH=640 +HEIGHT=640 ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}/parallel_logs" \ opensora/train/train_t2v_diffusers.py \ @@ -61,5 +61,5 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --trained_data_global_step 0 \ --group_data \ --mode 1 \ - # --sp_size 8 \ - # --train_sp_batch_size 1 \ + --sp_size 8 \ + --train_sp_batch_size 1 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh index 02d71499f9..8cc16c4978 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh @@ -1,7 +1,7 @@ # Stage 3: 93x480x480 (480x480, 640x352, 352x640) NUM_FRAME=93 -WIDTH=480 -HEIGHT=480 +WIDTH=640 +HEIGHT=352 ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}/parallel_logs" \ opensora/train/train_t2v_diffusers.py \ @@ -14,6 +14,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ --sample_rate 1 \ --num_frames ${NUM_FRAME} \ + --force_resolution \ --max_height ${HEIGHT} \ --max_width ${WIDTH} \ --interpolation_scale_t 1.0 \ From 91e075c7de562672608bb3b1b894109f107cd93b Mon Sep 17 00:00:00 2001 From: chenyingshu Date: Mon, 16 Dec 2024 11:02:47 +0800 Subject: [PATCH 076/133] update script --- .../scripts/text_condition/multi-devices/train_t2v_stage2.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh index 21452bbb59..988759a925 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh @@ -2,6 +2,7 @@ NUM_FRAME=93 WIDTH=640 HEIGHT=640 +MAX_HxW=409600 ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}/parallel_logs" \ opensora/train/train_t2v_diffusers.py \ @@ -16,6 +17,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --num_frames ${NUM_FRAME} \ --max_height ${HEIGHT} \ --max_width ${WIDTH} \ + --max_hxw ${MAX_HxW} \ --interpolation_scale_t 1.0 \ --interpolation_scale_h 1.0 \ --interpolation_scale_w 1.0 \ From 72f9e9cec955a836f7dcef4ed12692c29211bca7 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Sat, 30 Nov 2024 15:04:19 +0800 Subject: [PATCH 077/133] update causal cache --- .../causalvideovae/model/modules/conv.py | 18 ++-- .../model/modules/updownsample.py | 15 +-- .../causalvideovae/model/modules/wavelet.py | 6 +- .../model/vae/modeling_wfvae.py | 97 ++++++++++++------- 4 files changed, 83 insertions(+), 53 deletions(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py index bf8a5b34f0..f1244a49c1 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/conv.py @@ -1,4 +1,5 @@ import math +from collections import deque from typing import Tuple, Union import mindspore as ms @@ -115,7 +116,9 @@ def __init__( **kwargs, ) self.enable_cached = enable_cached - self.causal_cached = None + self.is_first_chunk = True + + self.causal_cached = deque() self.cache_offset = 0 def construct(self, x): @@ -123,25 +126,28 @@ def construct(self, x): # x: (bs, Cin, T, H, W ) # first_frame_pad = ops.repeat_interleave(first_frame, (self.time_kernel_size - 1), axis=2) if self.time_kernel_size - 1 > 0: - if self.causal_cached is None: + if self.is_first_chunk: first_frame = x[:, :, :1, :, :] first_frame_pad = mint.cat([first_frame] * (self.time_kernel_size - 1), dim=2) # first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) else: - first_frame_pad = self.causal_cached + first_frame_pad = self.causal_cached.popleft() x = mint.cat((first_frame_pad, x), dim=2) if self.enable_cached and self.time_kernel_size != 1: if (self.time_kernel_size - 1) // self.stride[0] != 0: if self.cache_offset == 0: - self.causal_cached = x[:, :, -(self.time_kernel_size - 1) // self.stride[0] :] + causal_cached = x[:, :, -(self.time_kernel_size - 1) // self.stride[0] :] else: - self.causal_cached = x[:, :, : -self.cache_offset][ + causal_cached = x[:, :, : -self.cache_offset][ :, :, -(self.time_kernel_size - 1) // self.stride[0] : ] else: - self.causal_cached = x[:, :, 0:0, :, :] + causal_cached = x[:, :, 0:0, :, :] + self.causal_cached.append(causal_cached.copy()) + elif self.enable_cached: + self.causal_cached.append(x[:, :, 0:0, :, :].copy()) if npu_config is not None and npu_config.on_npu: return npu_config.run_conv3d(self.conv, x, x_dtype) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py index 679f280940..86c269e064 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/updownsample.py @@ -1,4 +1,5 @@ import math +from collections import deque from typing import Tuple, Union from opensora.npu_config import npu_config @@ -324,19 +325,19 @@ def __init__( self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=1) self.interpolate = TrilinearInterpolate() self.enable_cached = enable_cached - self.causal_cached = None + self.causal_cached = deque() def construct(self, x): - if x.shape[2] > 1 or self.causal_cached is not None: - if self.enable_cached and self.causal_cached is not None: - x = mint.cat([self.causal_cached, x], dim=2) - self.causal_cached = x[:, :, -2:-1] + if x.shape[2] > 1 or len(self.causal_cached) > 0: + if self.enable_cached and len(self.causal_cached) > 0: + x = mint.cat([self.causal_cached.popleft(), x], dim=2) + self.causal_cached.append(x[:, :, -2:-1].copy()) x = npu_config.run_interpolate(self.interpolate, x, scale_factor=(2.0, 1.0, 1.0)) x = x[:, :, 2:] x = npu_config.run_interpolate(self.interpolate, x, scale_factor=(1.0, 2.0, 2.0)) else: if self.enable_cached: - self.causal_cached = x[:, :, -1:] + self.causal_cached.append(x[:, :, -1:].copy()) x, x_ = x[:, :, :1], x[:, :, 1:] x_ = npu_config.run_interpolate(self.interpolate, x_, scale_factor=(2.0, 1.0, 1.0)) x_ = npu_config.run_interpolate(self.interpolate, x_, scale_factor=(1.0, 2.0, 2.0)) @@ -344,7 +345,7 @@ def construct(self, x): x = mint.cat([x, x_], dim=2) else: if self.enable_cached: - self.causal_cached = x[:, :, -1:] + self.causal_cached.append(x[:, :, -1:].copy()) x = npu_config.run_interpolate(self.interpolate, x, scale_factor=(1.0, 2.0, 2.0)) return self.conv(x) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py index a5df5fcb13..09b2f937aa 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/modules/wavelet.py @@ -148,7 +148,7 @@ def __init__(self, enable_cached=False, dtype=ms.float32, *args, **kwargs) -> No self.hh_v = Tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]], dtype=dtype).view(1, 1, 2, 2, 2) * 0.3536 self.gh_v = Tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]], dtype=dtype).view(1, 1, 2, 2, 2) * 0.3536 self.enable_cached = enable_cached - self.causal_cached = None + self.is_first_chunk = True self.conv_transpose3d = ops.Conv3DTranspose(1, 1, kernel_size=2, stride=2) def construct(self, coeffs): @@ -185,7 +185,7 @@ def construct(self, coeffs): high_low_high = self.conv_transpose3d(high_low_high, self.g_v) high_high_low = self.conv_transpose3d(high_high_low, self.hh_v) high_high_high = self.conv_transpose3d(high_high_high, self.gh_v) - if self.enable_cached and self.causal_cached: + if self.enable_cached and not self.is_first_chunk: reconstructed = ( low_low_low + low_low_high @@ -207,7 +207,7 @@ def construct(self, coeffs): + high_high_low[:, :, 1:] + high_high_high[:, :, 1:] ) - self.causal_cached = True + reconstructed = reconstructed.reshape(b, -1, *reconstructed.shape[-3:]) return reconstructed.to(input_dtype) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py index aeeda6b452..bfa2ddbf6d 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py @@ -1,6 +1,7 @@ import logging import math import os +from collections import deque from typing import List from opensora.npu_config import npu_config @@ -16,7 +17,6 @@ from ..modeling_videobase import VideoBaseAE from ..modules import ( - AttnBlock3DFix, CausalConv3d, Conv2d, HaarWaveletTransform3D, @@ -41,6 +41,7 @@ def __init__( num_resblocks: int = 2, energy_flow_hidden_size: int = 64, dropout: float = 0.0, + attention_type: str = "AttnBlock3DFix", use_attention: bool = True, norm_type: str = "groupnorm", l1_dowmsample_block: str = "Downsample", @@ -145,19 +146,23 @@ def __init__( ), ] if use_attention: - mid_layers.insert(1, AttnBlock3DFix(in_channels=base_channels * 4, norm_type=norm_type, dtype=dtype)) + mid_layers.insert( + 1, resolve_str_to_obj(attention_type)(in_channels=base_channels * 4, norm_type=norm_type, dtype=dtype) + ) self.mid = nn.SequentialCell(*mid_layers) self.norm_out = Normalize(base_channels * 4, norm_type=norm_type) self.conv_out = CausalConv3d(base_channels * 4, latent_dim * 2, kernel_size=3, stride=1, padding=1) - self.wavelet_tranform_l1 = resolve_str_to_obj(l1_downsample_wavelet)(dtype=dtype) - self.wavelet_tranform_l2 = resolve_str_to_obj(l2_downsample_wavelet)(dtype=dtype) + self.wavelet_transform_in = HaarWaveletTransform3D() + self.wavelet_transform_l1 = resolve_str_to_obj(l1_downsample_wavelet)() + self.wavelet_transform_l2 = resolve_str_to_obj(l2_downsample_wavelet)() - def construct(self, coeffs): + def construct(self, x): + coeffs = self.wavelet_transform_in(x) l1_coeffs = coeffs[:, :3] - l1_coeffs = self.wavelet_tranform_l1(l1_coeffs) + l1_coeffs = self.wavelet_transform_l1(l1_coeffs) l1 = self.connect_l1(l1_coeffs) - l2_coeffs = self.wavelet_tranform_l2(l1_coeffs[:, :3]) + l2_coeffs = self.wavelet_transform_l2(l1_coeffs[:, :3]) l2 = self.connect_l2(l2_coeffs) h = self.down1(coeffs) @@ -173,7 +178,7 @@ def construct(self, coeffs): h = nonlinearity(h) h = self.conv_out(h) - return h + return h, (l1_coeffs, l2_coeffs) class Decoder(VideoBaseAE): @@ -185,6 +190,7 @@ def __init__( num_resblocks: int = 2, dropout: float = 0.0, energy_flow_hidden_size: int = 128, + attention_type: str = "AttnBlock3DFix", use_attention: bool = True, norm_type: str = "groupnorm", t_interpolation: str = "nearest", @@ -217,7 +223,9 @@ def __init__( ), ] if use_attention: - mid_layers.insert(1, AttnBlock3DFix(in_channels=base_channels * 4, norm_type=norm_type, dtype=dtype)) + mid_layers.insert( + 1, resolve_str_to_obj(attention_type)(in_channels=base_channels * 4, norm_type=norm_type, dtype=dtype) + ) self.mid = nn.SequentialCell(*mid_layers) self.up2 = nn.SequentialCell( *[ @@ -283,8 +291,8 @@ def __init__( self.connect_l1 = nn.SequentialCell( *[ ResnetBlock3D( - in_channels=base_channels, - out_channels=base_channels, + in_channels=energy_flow_hidden_size, + out_channels=energy_flow_hidden_size, dropout=dropout, norm_type=norm_type, dtype=dtype, @@ -292,7 +300,7 @@ def __init__( for _ in range(connect_res_layer_num) ], Conv2d( - base_channels, + energy_flow_hidden_size, l1_channels, kernel_size=3, stride=1, @@ -300,14 +308,14 @@ def __init__( pad_mode="pad", has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(base_channels * 3 * 3)), + bias_init=Uniform(scale=1 / math.sqrt(energy_flow_hidden_size * 3 * 3)), ).to_float(dtype), ) self.connect_l2 = nn.SequentialCell( *[ ResnetBlock3D( - in_channels=base_channels, - out_channels=base_channels, + in_channels=energy_flow_hidden_size, + out_channels=energy_flow_hidden_size, dropout=dropout, norm_type=norm_type, dtype=dtype, @@ -315,7 +323,7 @@ def __init__( for _ in range(connect_res_layer_num) ], Conv2d( - base_channels, + energy_flow_hidden_size, 24, kernel_size=3, stride=1, @@ -323,7 +331,7 @@ def __init__( pad_mode="pad", has_bias=True, weight_init=HeUniform(negative_slope=math.sqrt(5)), - bias_init=Uniform(scale=1 / math.sqrt(base_channels * 3 * 3)), + bias_init=Uniform(scale=1 / math.sqrt(energy_flow_hidden_size * 3 * 3)), ).to_float(dtype), ) # Out @@ -340,21 +348,22 @@ def __init__( bias_init=Uniform(scale=1 / math.sqrt(base_channels * 3 * 3)), ).to_float(dtype) - self.inverse_wavelet_tranform_l1 = resolve_str_to_obj(l1_upsample_wavelet)(dtype=dtype) - self.inverse_wavelet_tranform_l2 = resolve_str_to_obj(l2_upsample_wavelet)(dtype=dtype) + self.inverse_wavelet_transform_out = InverseHaarWaveletTransform3D() + self.inverse_wavelet_transform_l1 = resolve_str_to_obj(l1_upsample_wavelet)() + self.inverse_wavelet_transform_l2 = resolve_str_to_obj(l2_upsample_wavelet)() def construct(self, z): h = self.conv_in(z) h = self.mid(h) l2_coeffs = self.connect_l2(h[:, -self.energy_flow_hidden_size :]) - l2 = self.inverse_wavelet_tranform_l2(l2_coeffs) + l2 = self.inverse_wavelet_transform_l2(l2_coeffs) h = self.up2(h[:, : -self.energy_flow_hidden_size]) l1_coeffs = h[:, -self.energy_flow_hidden_size :] l1_coeffs = self.connect_l1(l1_coeffs) l1_coeffs[:, :3] = l1_coeffs[:, :3] + l2 - l1 = self.inverse_wavelet_tranform_l1(l1_coeffs) + l1 = self.inverse_wavelet_transform_l1(l1_coeffs) h = self.up1(h[:, : -self.energy_flow_hidden_size]) @@ -363,7 +372,8 @@ def construct(self, z): h = nonlinearity(h) h = self.conv_out(h) h[:, :3] = h[:, :3] + l1 - return h + dec = self.inverse_wavelet_transform_out(h) + return dec, (l1_coeffs, l2_coeffs) @ModelRegistry.register("WFVAE") @@ -377,6 +387,7 @@ def __init__( encoder_energy_flow_hidden_size: int = 64, decoder_num_resblocks: int = 2, decoder_energy_flow_hidden_size: int = 128, + attention_type: str = "AttnBlock3DFix", use_attention: bool = True, dropout: float = 0.0, norm_type: str = "groupnorm", @@ -400,9 +411,10 @@ def __init__( self.use_tiling = False # Hardcode for now - self.t_chunk_enc = 16 + self.t_chunk_enc = 8 + self.t_chunk_dec = 2 self.t_upsample_times = 4 // 2 - self.t_chunk_dec = 4 + self.use_quant_layer = False self.encoder = Encoder( latent_dim=latent_dim, @@ -416,6 +428,7 @@ def __init__( l1_downsample_wavelet=l1_downsample_wavelet, l2_dowmsample_block=l2_dowmsample_block, l2_downsample_wavelet=l2_downsample_wavelet, + attention_type=attention_type, dtype=dtype, ) self.decoder = Decoder( @@ -432,6 +445,7 @@ def __init__( l1_upsample_wavelet=l1_upsample_wavelet, l2_upsample_block=l2_upsample_block, l2_upsample_wavelet=l2_upsample_wavelet, + attention_type=attention_type, dtype=dtype, ) @@ -474,7 +488,7 @@ def get_decoder(self): def _empty_causal_cached(self, parent): for name, module in parent.cells_and_names(): if hasattr(module, "causal_cached"): - module.causal_cached = None + module.causal_cached = deque() def _set_causal_cached(self, enable_cached=True): for name, module in self.cells_and_names(): @@ -487,6 +501,11 @@ def _set_cache_offset(self, modules, cache_offset=0): if hasattr(submodule, "cache_offset"): submodule.cache_offset = cache_offset + def _set_first_chunk(self, is_first_chunk=True): + for _, module in self.cells_and_names(): + if hasattr(module, "is_first_chunk"): + module.is_first_chunk = is_first_chunk + def build_chunk_start_end(self, t, decoder_mode=False): start_end = [[0, 1]] start = 1 @@ -510,13 +529,13 @@ def encode(self, x, sample_posterior=True): def _encode(self, x): self._empty_causal_cached(self.encoder) - - coeffs = HaarWaveletTransform3D()(x) + self._set_first_chunk(True) if self.use_tiling: - h = self.tile_encode(coeffs) + h = self.tile_encode(x) + # l1, l2 = None, None else: - h = self.encoder(coeffs) + h, _ = self.encoder(x) if self.use_quant_layer: h = self.quant_conv(h) posterior_mean, posterior_logvar = mint.split(h, [h.shape[1] // 2, h.shape[1] // 2], dim=1) @@ -527,26 +546,28 @@ def tile_encode(self, x): start_end = self.build_chunk_start_end(t) result = [] - for start, end in start_end: + for idx, (start, end) in enumerate(start_end): + self._set_first_chunk(idx == 0) chunk = x[:, :, start:end, :, :] - chunk = self.encoder(chunk) + chunk = self.encoder(chunk)[0] if self.use_quant_layer: - chunk = self.encoder(chunk) + chunk = self.quant_conv(chunk) result.append(chunk) return mint.cat(result, dim=2) def decode(self, z): self._empty_causal_cached(self.decoder) + self._set_first_chunk(True) if self.use_tiling: dec = self.tile_decode(z) + # l1, l2 = None, None else: if self.use_quant_layer: z = self.post_quant_conv(z) - dec = self.decoder(z) + dec, _ = self.decoder(z) - dec = InverseHaarWaveletTransform3D()(dec) return dec def tile_decode(self, x): @@ -555,7 +576,9 @@ def tile_decode(self, x): start_end = self.build_chunk_start_end(t, decoder_mode=True) result = [] - for start, end in start_end: + for idx, (start, end) in enumerate(start_end): + self._set_first_chunk(idx == 0) + if end + 1 < t: chunk = x[:, :, start : end + 1, :, :] else: @@ -563,10 +586,10 @@ def tile_decode(self, x): if self.use_quant_layer: chunk = self.post_quant_conv(chunk) - chunk = self.decoder(chunk) + chunk = self.decoder(chunk)[0] if end + 1 < t: - chunk = chunk[:, :, :-2] + chunk = chunk[:, :, :-4] result.append(chunk) else: result.append(chunk) From 28fc60c523785507505407f96ad77c0862bbc066 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Sat, 30 Nov 2024 15:24:45 +0800 Subject: [PATCH 078/133] remove vae amp --- .../opensora_pku/opensora/sample/sample_text_embed.py | 4 ---- examples/opensora_pku/opensora/utils/sample_utils.py | 9 --------- 2 files changed, 13 deletions(-) diff --git a/examples/opensora_pku/opensora/sample/sample_text_embed.py b/examples/opensora_pku/opensora/sample/sample_text_embed.py index 718909ceee..688cd76aec 100644 --- a/examples/opensora_pku/opensora/sample/sample_text_embed.py +++ b/examples/opensora_pku/opensora/sample/sample_text_embed.py @@ -23,10 +23,6 @@ from transformers import AutoTokenizer from mindone.transformers import MT5EncoderModel - -# from mindone.transformers.activations import NewGELUActivation -# from mindone.transformers.models.mt5.modeling_mt5 import MT5LayerNorm -# from mindone.utils.amp import auto_mixed_precision from mindone.utils.config import str2bool from mindone.utils.logger import set_logger diff --git a/examples/opensora_pku/opensora/utils/sample_utils.py b/examples/opensora_pku/opensora/utils/sample_utils.py index 10c92a4aa1..d3e199d0a6 100644 --- a/examples/opensora_pku/opensora/utils/sample_utils.py +++ b/examples/opensora_pku/opensora/utils/sample_utils.py @@ -12,7 +12,6 @@ from opensora.models.causalvideovae import ae_stride_config, ae_wrapper # from opensora.sample.caption_refiner import OpenSoraCaptionRefiner -from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate from opensora.models.diffusion.common import PatchEmbed2D from opensora.models.diffusion.opensora.modeling_opensora import OpenSoraT2V_v1_3 from opensora.models.diffusion.opensora.modules import Attention, LayerNorm @@ -119,14 +118,6 @@ def prepare_pipeline(args): vae.vae.enable_tiling() vae.vae.tile_overlap_factor = args.tile_overlap_factor - # use amp level O2 for causal 3D VAE with bfloat16 or float16 - if vae_dtype == ms.float16: - custom_fp32_cells = [nn.GroupNorm] if args.vae_keep_gn_fp32 else [] - else: - custom_fp32_cells = [nn.AvgPool2d, TrilinearInterpolate] - logger.info(f"Use amp level O2 for causal 3D VAE with dtype={vae_dtype}, custom_fp32_cells: {custom_fp32_cells}") - vae = auto_mixed_precision(vae, amp_level="O2", dtype=vae_dtype, custom_fp32_cells=custom_fp32_cells) - vae.set_train(False) for param in vae.get_parameters(): # freeze vae param.requires_grad = False From c8cb23d377f709a9390526ca656ad70a19cd88ae Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Sat, 30 Nov 2024 15:25:58 +0800 Subject: [PATCH 079/133] remove vae ms_checkpoint in sample.py --- examples/opensora_pku/opensora/utils/sample_utils.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/examples/opensora_pku/opensora/utils/sample_utils.py b/examples/opensora_pku/opensora/utils/sample_utils.py index d3e199d0a6..14ba261a9a 100644 --- a/examples/opensora_pku/opensora/utils/sample_utils.py +++ b/examples/opensora_pku/opensora/utils/sample_utils.py @@ -98,17 +98,7 @@ def prepare_pipeline(args): # VAE model initiate and weight loading print_banner("vae init") vae_dtype = get_precision(args.vae_precision) - if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint): - logger.info(f"Run inference with MindSpore checkpoint {args.ms_checkpoint}") - state_dict = ms.load_checkpoint(args.ms_checkpoint) - # rm 'network.' prefix - state_dict = dict( - [k.replace("network.", "") if k.startswith("network.") else k, v] for k, v in state_dict.items() - ) - else: - state_dict = None kwarg = { - "state_dict": state_dict, "use_safetensors": True, "dtype": vae_dtype, } From 7b2772c4dbb7ad710a1f444f2fd3e414e185d329 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Sat, 30 Nov 2024 15:30:34 +0800 Subject: [PATCH 080/133] save config.json in train.py --- examples/opensora_pku/opensora/train/train_t2v_diffusers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index 33949233c5..679d9f5a95 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -28,7 +28,7 @@ from opensora.utils.dataset_utils import Collate, LengthGroupedBatchSampler from opensora.utils.ema import EMA from opensora.utils.message_utils import print_banner -from opensora.utils.utils import get_precision +from opensora.utils.utils import get_precision, save_diffusers_json from mindone.diffusers.models.activations import SiLU from mindone.diffusers.schedulers import FlowMatchEulerDiscreteScheduler # CogVideoXDDIMScheduler, @@ -161,7 +161,8 @@ def main(args): num_no_recompute=args.num_no_recompute, FA_dtype=FA_dtype, ) - + json_name = os.path.join(args.output_dir, "config.json") + save_diffusers_json(model.config, json_name) # mixed precision if args.precision == "fp32": model_dtype = get_precision(args.precision) From 1e3c33d6b342c8bc3ca7aa3a77c91cc58483a7b2 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Sat, 30 Nov 2024 15:42:45 +0800 Subject: [PATCH 081/133] save_ema_only is False --- .../opensora_pku/opensora/train/commons.py | 6 +++++ .../opensora/train/train_causalvae.py | 24 +++++++++++++++---- .../opensora/train/train_t2v_diffusers.py | 2 +- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/examples/opensora_pku/opensora/train/commons.py b/examples/opensora_pku/opensora/train/commons.py index b4e49493c2..f819f4c249 100644 --- a/examples/opensora_pku/opensora/train/commons.py +++ b/examples/opensora_pku/opensora/train/commons.py @@ -197,6 +197,12 @@ def parse_train_args(parser): type=str2bool, help="whether save ckpt by steps. If False, save ckpt by epochs.", ) + parser.add_argument( + "--save_ema_only", + default=False, + type=str2bool, + help="whether save ema ckpt only. If False, and when ema during training is enabled, it will save both ema and non-ema.ckpt", + ) parser.add_argument( "--validate", default=False, diff --git a/examples/opensora_pku/opensora/train/train_causalvae.py b/examples/opensora_pku/opensora/train/train_causalvae.py index 48331be9a3..9670c8ca7c 100644 --- a/examples/opensora_pku/opensora/train/train_causalvae.py +++ b/examples/opensora_pku/opensora/train/train_causalvae.py @@ -514,13 +514,22 @@ def main(args): if rank_id == 0 and step_mode: cur_epoch = epoch + 1 if (cur_global_step % ckpt_save_interval == 0) or (cur_global_step == total_train_steps): + ae_with_loss.set_train(False) + disc_with_loss.set_train(False) ckpt_name = ( f"vae_3d-e{cur_epoch}.ckpt" if not use_step_unit else f"vae_3d-s{cur_global_step}.ckpt" ) + if not args.save_ema_only and ema is not None: + ckpt_manager.save( + ae_with_loss.autoencoder, + None, + ckpt_name=ckpt_name.replace(".ckpt", "_nonema.ckpt"), + append_dict=None, + ) + if ema is not None: ema.swap_before_eval() - ae_with_loss.set_train(False) - disc_with_loss.set_train(False) + ckpt_manager.save(ae_with_loss.autoencoder, None, ckpt_name=ckpt_name, append_dict=None) if args.save_training_resume: ms.save_checkpoint( @@ -557,11 +566,18 @@ def main(args): if rank_id == 0 and not step_mode: if (cur_epoch % ckpt_save_interval == 0) or (cur_epoch == args.epochs): + ae_with_loss.set_train(False) + disc_with_loss.set_train(False) ckpt_name = f"vae_3d-e{cur_epoch}.ckpt" if not use_step_unit else f"vae_3d-s{cur_global_step}.ckpt" + if not args.save_ema_only and ema is not None: + ckpt_manager.save( + ae_with_loss.autoencoder, + None, + ckpt_name=ckpt_name.replace(".ckpt", "_nonema.ckpt"), + append_dict=None, + ) if ema is not None: ema.swap_before_eval() - ae_with_loss.set_train(False) - disc_with_loss.set_train(False) ckpt_manager.save(ae_with_loss.autoencoder, None, ckpt_name=ckpt_name, append_dict=None) if args.save_training_resume: ms.save_checkpoint( diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index 679d9f5a95..54eefaab22 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -642,7 +642,7 @@ def main(args): ckpt_save_dir=ckpt_save_dir, output_dir=output_dir, ema=ema, - save_ema_only=False, + save_ema_only=args.save_ema_only, ckpt_save_policy="latest_k", ckpt_max_keep=ckpt_max_keep, step_mode=step_mode, From a9786af9022a38f4b41cd6d961084da4d55ca2d0 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Sat, 30 Nov 2024 15:46:04 +0800 Subject: [PATCH 082/133] update npu_config.init_env --- .../causalvideovae/sample/rec_video_vae.py | 17 +++-------------- .../opensora/sample/sample_text_embed.py | 12 ++---------- 2 files changed, 5 insertions(+), 24 deletions(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/sample/rec_video_vae.py b/examples/opensora_pku/opensora/models/causalvideovae/sample/rec_video_vae.py index e4d91b24d2..7e09aee4fc 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/sample/rec_video_vae.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/sample/rec_video_vae.py @@ -8,25 +8,14 @@ from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info from opensora.models.causalvideovae.model import ModelRegistry from opensora.models.causalvideovae.model.dataset_videobase import VideoDataset, create_dataloader -from opensora.utils.ms_utils import init_env +from opensora.npu_config import npu_config from opensora.utils.utils import get_precision from opensora.utils.video_utils import save_videos def main(args: argparse.Namespace): - rank_id, device_num = init_env( - args.mode, - seed=args.seed, - distributed=args.use_parallel, - device_target=args.device, - max_device_memory=args.max_device_memory, - parallel_mode=args.parallel_mode, - precision_mode=args.precision_mode, - sp_size=args.sp_size, - jit_level=args.jit_level, - jit_syntax_level=args.jit_syntax_level, - ) - + rank_id, device_num = npu_config.set_npu_env(args) + npu_config.print_ops_dtype_info() real_video_dir = args.real_video_dir generated_video_dir = args.generated_video_dir sample_rate = args.sample_rate diff --git a/examples/opensora_pku/opensora/sample/sample_text_embed.py b/examples/opensora_pku/opensora/sample/sample_text_embed.py index 688cd76aec..61aba003e9 100644 --- a/examples/opensora_pku/opensora/sample/sample_text_embed.py +++ b/examples/opensora_pku/opensora/sample/sample_text_embed.py @@ -17,8 +17,8 @@ sys.path.append(os.path.abspath("./")) from opensora.dataset.text_dataset import create_dataloader from opensora.dataset.transform import t5_text_preprocessing as text_preprocessing +from opensora.npu_config import npu_config from opensora.utils.message_utils import print_banner -from opensora.utils.ms_utils import init_env from opensora.utils.utils import get_precision from transformers import AutoTokenizer @@ -47,15 +47,7 @@ def read_captions_from_txt(path): def main(args): set_logger(name="", output_dir="logs/infer_mt5") - - rank_id, device_num = init_env( - mode=args.mode, - seed=args.seed, - distributed=args.use_parallel, - device_target=args.device_target, - jit_level=args.jit_level, - jit_syntax_level=args.jit_syntax_level, - ) + rank_id, device_num = npu_config.set_npu_env(args) print(f"rank_id {rank_id}, device_num {device_num}") # build dataloader for large amount of captions From 3542a62b9311fba26620f01840322bd7307411dd Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Sat, 30 Nov 2024 15:47:40 +0800 Subject: [PATCH 083/133] ema is turned off by default --- .../scripts/causalvae/train_with_gan_loss.sh | 2 +- .../train_with_gan_loss_multi_device.sh | 2 +- .../multi-devices/sample_t2i_1x480p_ddp.sh | 21 ------------------- .../multi-devices/train_debug.sh | 2 +- .../multi-devices/train_t2v_stage1.sh | 2 +- .../multi-devices/train_t2v_stage2.sh | 2 +- .../multi-devices/train_t2v_stage3.sh | 2 +- .../single-device/train_t2v_stage1.sh | 2 +- .../single-device/train_t2v_stage2.sh | 2 +- .../single-device/train_t2v_stage3.sh | 2 +- 10 files changed, 9 insertions(+), 30 deletions(-) delete mode 100644 examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2i_1x480p_ddp.sh diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh index 931d3ad985..37c506f752 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh @@ -22,7 +22,7 @@ python opensora/train/train_causalvae.py \ --init_loss_scale 65536 \ --jit_level "O0" \ --use_discriminator True \ - --use_ema True\ + --use_ema False \ --ema_decay 0.999 \ --perceptual_weight 1.0 \ --loss_type l1 \ diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh index 73215d0767..3e349db12f 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh +++ b/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh @@ -30,7 +30,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --jit_level "O0" \ --use_discriminator True \ --use_parallel True \ - --use_ema True\ + --use_ema False \ --ema_decay 0.999 \ --perceptual_weight 1.0 \ --loss_type l1 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2i_1x480p_ddp.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2i_1x480p_ddp.sh deleted file mode 100644 index 619b817a9f..0000000000 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2i_1x480p_ddp.sh +++ /dev/null @@ -1,21 +0,0 @@ - -msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="./sample_videos/prompt_list_0_1x480p_ddp/parallel_logs/" \ - opensora/sample/sample_t2v.py \ - --model_path LanguageBind/Open-Sora-Plan-v1.2.0/1x480p \ - --num_frames 1 \ - --height 480 \ - --width 640 \ - --cache_dir "./" \ - --text_encoder_name google/mt5-xxl \ - --text_prompt examples/prompt_list_0.txt \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path LanguageBind/Open-Sora-Plan-v1.2.0/vae\ - --save_img_path "./sample_videos/prompt_list_0_1x480p_ddp" \ - --fps 24 \ - --guidance_scale 4.5 \ - --num_sampling_steps 100 \ - --enable_tiling \ - --max_sequence_length 512 \ - --sample_method EulerAncestralDiscrete \ - --model_type "dit" \ - --use_parallel True \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh index 02ac55514f..2b000e91c9 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh @@ -35,7 +35,7 @@ msrun --bind_core=True --worker_num=4 --local_worker_num=4 --master_port=6000 -- --use_image_num 0 \ --cfg 0.1 \ --snr_gamma 5.0 \ - --use_ema True\ + --use_ema False \ --ema_start_step 0 \ --enable_tiling \ --tile_overlap_factor 0.125 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage1.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage1.sh index ce393f2490..27fe901180 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage1.sh @@ -37,7 +37,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --cfg 0.1 \ --snr_gamma 5.0 \ --rescale_betas_zero_snr \ - --use_ema True\ + --use_ema False \ --ema_start_step 0 \ --enable_tiling \ --tile_overlap_factor 0.125 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh index 988759a925..ae412a8d1e 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh @@ -38,7 +38,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --cfg 0.1 \ --snr_gamma 5.0 \ --rescale_betas_zero_snr \ - --use_ema True\ + --use_ema False \ --ema_start_step 0 \ --enable_tiling \ --tile_overlap_factor 0.125 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh index 8cc16c4978..148e0eae53 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh @@ -37,7 +37,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --cfg 0.1 \ --snr_gamma 5.0 \ --rescale_betas_zero_snr \ - --use_ema True\ + --use_ema False \ --ema_start_step 0 \ --enable_tiling \ --tile_overlap_factor 0.125 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage1.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage1.sh index fe6f01b9d9..f43df4bdee 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage1.sh @@ -34,7 +34,7 @@ python opensora/train/train_t2v_diffusers.py \ --cfg 0.1 \ --snr_gamma 5.0 \ --rescale_betas_zero_snr \ - --use_ema True\ + --use_ema False \ --ema_start_step 0 \ --enable_tiling \ --tile_overlap_factor 0.125 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh index 94e09d4246..80569bbe3c 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh @@ -34,7 +34,7 @@ python opensora/train/train_t2v_diffusers.py \ --cfg 0.1 \ --snr_gamma 5.0 \ --rescale_betas_zero_snr \ - --use_ema True\ + --use_ema False \ --ema_start_step 0 \ --enable_tiling \ --tile_overlap_factor 0.125 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh index b9cd232a32..5a033aa166 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh @@ -34,7 +34,7 @@ python opensora/train/train_t2v_diffusers.py \ --cfg 0.1 \ --snr_gamma 5.0 \ --rescale_betas_zero_snr \ - --use_ema True\ + --use_ema False \ --ema_start_step 0 \ --enable_tiling \ --tile_overlap_factor 0.125 \ From 7a256594f44e6f12e90c039d40ecef6da7f02d2a Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Sat, 30 Nov 2024 15:51:34 +0800 Subject: [PATCH 084/133] update sample & train scripts --- .../single-device/sample_t2i_1x320x320.sh | 28 +++++++++++++++++++ ...rain_t2v_stage1.sh => train_t2i_stage1.sh} | 2 +- 2 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x320x320.sh rename examples/opensora_pku/scripts/text_condition/single-device/{train_t2v_stage1.sh => train_t2i_stage1.sh} (96%) diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x320x320.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x320x320.sh new file mode 100644 index 0000000000..29b6489612 --- /dev/null +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x320x320.sh @@ -0,0 +1,28 @@ +# The DiT model is trained arbitrarily on stride=32. +# So keep the resolution of the inference a multiple of 32. Frames needs to be 4n+1, e.g. 93, 77, 61, 45, 29, 1 (image). + +export DEVICE_ID=0 +python opensora/sample/sample.py \ + --model_path LanguageBind/Open-Sora-Plan-v1.3.0/1x320x320 \ + --version v1_3 \ + --num_frames 1 \ + --height 320 \ + --width 320 \ + --text_encoder_name_1 google/mt5-xxl \ + --text_prompt examples/prompt_list_human_images.txt \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ + --save_img_path "./sample_videos/human_images" \ + --fps 18 \ + --guidance_scale 7.5 \ + --num_sampling_steps 100 \ + --enable_tiling \ + --max_sequence_length 512 \ + --sample_method EulerAncestralDiscrete \ + --seed 1234 \ + --num_samples_per_prompt 1 \ + --rescale_betas_zero_snr \ + --prediction_type "v_prediction" \ + --mode 1 \ + --precision bf16 \ + --ms_checkpoint ckpt/path \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage1.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh similarity index 96% rename from examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage1.sh rename to examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh index f43df4bdee..35b1bae836 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh @@ -7,7 +7,7 @@ python opensora/train/train_t2v_diffusers.py \ --text_encoder_name_1 google/mt5-xxl \ --cache_dir "./" \ --dataset t2v \ - --data "scripts/train_data/image_data_v1_2.txt" \ + --data "scripts/train_data/merge_data_human_image.txt" \ --ae WFVAEModel_D8_4x8x8 \ --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ --sample_rate 1 \ From 03f68a63ef93351d4a2cee5e5f0b4baacbd77b6a Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Sat, 30 Nov 2024 15:52:45 +0800 Subject: [PATCH 085/133] update train script name --- .../multi-devices/{train_t2v_stage1.sh => train_t2i_stage1.sh} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/opensora_pku/scripts/text_condition/multi-devices/{train_t2v_stage1.sh => train_t2i_stage1.sh} (100%) diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage1.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh similarity index 100% rename from examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage1.sh rename to examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh From a2f1ad960490a56fa5b2cc287f46f3b90767666b Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Sat, 30 Nov 2024 16:25:09 +0800 Subject: [PATCH 086/133] tile_overlap_factor remove --- examples/opensora_pku/README.md | 3 +-- examples/opensora_pku/scripts/causalvae/rec_video_folder.sh | 1 - .../scripts/text_condition/multi-devices/train_debug.sh | 1 - .../scripts/text_condition/multi-devices/train_t2i_stage1.sh | 1 - .../scripts/text_condition/multi-devices/train_t2v_stage2.sh | 1 - .../scripts/text_condition/multi-devices/train_t2v_stage3.sh | 1 - .../scripts/text_condition/single-device/train_t2i_stage1.sh | 1 - .../scripts/text_condition/single-device/train_t2v_stage2.sh | 1 - .../scripts/text_condition/single-device/train_t2v_stage3.sh | 1 - 9 files changed, 1 insertion(+), 10 deletions(-) diff --git a/examples/opensora_pku/README.md b/examples/opensora_pku/README.md index 4892957897..0e65c062d2 100644 --- a/examples/opensora_pku/README.md +++ b/examples/opensora_pku/README.md @@ -212,7 +212,7 @@ python opensora/sample/sample.py \ --mode 1 ``` You can change the `num_frames`, `height` and `width`. -Note that DiT model is trained arbitrarily on stride=32. +Note that DiT model is trained arbitrarily on stride=32. So keep the resolution of the inference a multiple of 32. `num_frames` needs to be 4n+1, e.g. 93, 77, 61, 45, 29, 1. @@ -345,7 +345,6 @@ python examples/rec_video_folder.py \ --num_workers 8 \ --ae_path LanguageBind/Open-Sora-Plan-v1.2.0/vae \ --enable_tiling \ - --tile_overlap_factor 0.125 \ --save_memory \ --ms_checkpoint /path/to/ms/checkpoint \ ``` diff --git a/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh b/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh index 002260191d..640c54a5be 100644 --- a/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh +++ b/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh @@ -13,6 +13,5 @@ python examples/rec_video_folder.py \ --ae "WFVAEModel_D8_4x8x8" \ --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ --enable_tiling \ - --tile_overlap_factor 0.125 \ --mode 1 \ --jit_syntax_level lax \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh index 2b000e91c9..4f7e89e229 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh @@ -38,7 +38,6 @@ msrun --bind_core=True --worker_num=4 --local_worker_num=4 --master_port=6000 -- --use_ema False \ --ema_start_step 0 \ --enable_tiling \ - --tile_overlap_factor 0.125 \ --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh index 27fe901180..7e93b65445 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh @@ -40,7 +40,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --use_ema False \ --ema_start_step 0 \ --enable_tiling \ - --tile_overlap_factor 0.125 \ --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh index ae412a8d1e..68d0399aec 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh @@ -41,7 +41,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --use_ema False \ --ema_start_step 0 \ --enable_tiling \ - --tile_overlap_factor 0.125 \ --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh index 148e0eae53..b8a5178deb 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh @@ -40,7 +40,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --use_ema False \ --ema_start_step 0 \ --enable_tiling \ - --tile_overlap_factor 0.125 \ --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh index 35b1bae836..75d55b51f6 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh @@ -37,7 +37,6 @@ python opensora/train/train_t2v_diffusers.py \ --use_ema False \ --ema_start_step 0 \ --enable_tiling \ - --tile_overlap_factor 0.125 \ --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh index 80569bbe3c..4c577928e5 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh @@ -37,7 +37,6 @@ python opensora/train/train_t2v_diffusers.py \ --use_ema False \ --ema_start_step 0 \ --enable_tiling \ - --tile_overlap_factor 0.125 \ --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh index 5a033aa166..6a91ddb451 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh @@ -37,7 +37,6 @@ python opensora/train/train_t2v_diffusers.py \ --use_ema False \ --ema_start_step 0 \ --enable_tiling \ - --tile_overlap_factor 0.125 \ --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ From fbdeadc26835c494a9725953d9ae725fa2822d08 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 3 Dec 2024 17:23:27 +0800 Subject: [PATCH 087/133] remove config.recompute --- examples/opensora_pku/opensora/train/train_causalvae.py | 6 +++++- examples/opensora_pku/opensora/train/train_t2v_diffusers.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/opensora_pku/opensora/train/train_causalvae.py b/examples/opensora_pku/opensora/train/train_causalvae.py index 9670c8ca7c..ac08250e02 100644 --- a/examples/opensora_pku/opensora/train/train_causalvae.py +++ b/examples/opensora_pku/opensora/train/train_causalvae.py @@ -8,6 +8,7 @@ import sys import time +import deepcopy import yaml import mindspore as ms @@ -91,7 +92,10 @@ def main(args): logger.warning(f"Model will be initialized from config file {args.model_config}.") ae = model_cls.from_config(args.model_config, dtype=dtype, use_recompute=args.use_recompute) json_name = os.path.join(args.output_dir, "config.json") - save_diffusers_json(ae.config, json_name) + config = deepcopy.copy(ae.config) + if hasattr(config, "recompute"): + del config.recompute + save_diffusers_json(config, json_name) if args.load_from_checkpoint is not None: ae.init_from_ckpt(args.load_from_checkpoint) # discriminator (D) diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index 54eefaab22..e59e7c9700 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -3,6 +3,7 @@ import os import sys +import deepcopy import yaml import mindspore as ms @@ -162,7 +163,10 @@ def main(args): FA_dtype=FA_dtype, ) json_name = os.path.join(args.output_dir, "config.json") - save_diffusers_json(model.config, json_name) + config = deepcopy.copy(model.config) + if hasattr(config, "recompute"): + del config.recompute + save_diffusers_json(config, json_name) # mixed precision if args.precision == "fp32": model_dtype = get_precision(args.precision) From a2abf1dda03c2360c5eac9ffa5895594b931a173 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 3 Dec 2024 17:25:39 +0800 Subject: [PATCH 088/133] load ms_checkpoint filter --- examples/opensora_pku/examples/rec_image.py | 1 + examples/opensora_pku/examples/rec_video.py | 1 + examples/opensora_pku/examples/rec_video_folder.py | 1 + examples/opensora_pku/opensora/sample/rec_image.py | 3 ++- examples/opensora_pku/opensora/sample/rec_video.py | 3 ++- examples/opensora_pku/opensora/utils/sample_utils.py | 1 + 6 files changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/opensora_pku/examples/rec_image.py b/examples/opensora_pku/examples/rec_image.py index c3acfff217..24e4a98b0c 100644 --- a/examples/opensora_pku/examples/rec_image.py +++ b/examples/opensora_pku/examples/rec_image.py @@ -76,6 +76,7 @@ def main(args): state_dict = dict( [k.replace("autoencoder.", "") if k.startswith("autoencoder.") else k, v] for k, v in state_dict.items() ) + state_dict = dict([k.replace("_backbone.", "") if "_backbone." in k else k, v] for k, v in state_dict.items()) else: state_dict = None vae = None diff --git a/examples/opensora_pku/examples/rec_video.py b/examples/opensora_pku/examples/rec_video.py index 28334b055d..7af5497474 100644 --- a/examples/opensora_pku/examples/rec_video.py +++ b/examples/opensora_pku/examples/rec_video.py @@ -114,6 +114,7 @@ def main(args): state_dict = dict( [k.replace("autoencoder.", "") if k.startswith("autoencoder.") else k, v] for k, v in state_dict.items() ) + state_dict = dict([k.replace("_backbone.", "") if "_backbone." in k else k, v] for k, v in state_dict.items()) else: state_dict = None diff --git a/examples/opensora_pku/examples/rec_video_folder.py b/examples/opensora_pku/examples/rec_video_folder.py index 9f6195ee96..1b809c7a24 100644 --- a/examples/opensora_pku/examples/rec_video_folder.py +++ b/examples/opensora_pku/examples/rec_video_folder.py @@ -60,6 +60,7 @@ def main(args): state_dict = dict( [k.replace("autoencoder.", "") if k.startswith("autoencoder.") else k, v] for k, v in state_dict.items() ) + state_dict = dict([k.replace("_backbone.", "") if "_backbone." in k else k, v] for k, v in state_dict.items()) else: state_dict = None diff --git a/examples/opensora_pku/opensora/sample/rec_image.py b/examples/opensora_pku/opensora/sample/rec_image.py index ed0e84e8b0..2f7c9f8822 100644 --- a/examples/opensora_pku/opensora/sample/rec_image.py +++ b/examples/opensora_pku/opensora/sample/rec_image.py @@ -72,8 +72,9 @@ def main(args): state_dict = ms.load_checkpoint(args.ms_checkpoint) state_dict = dict( - [k.replace("network.", "") if k.startswith("network.") else k, v] for k, v in state_dict.items() + [k.replace("autoencoder.", "") if k.startswith("autoencoder.") else k, v] for k, v in state_dict.items() ) + state_dict = dict([k.replace("_backbone.", "") if "_backbone." in k else k, v] for k, v in state_dict.items()) else: state_dict = None kwarg = { diff --git a/examples/opensora_pku/opensora/sample/rec_video.py b/examples/opensora_pku/opensora/sample/rec_video.py index 17abb677fa..99a534f7c3 100644 --- a/examples/opensora_pku/opensora/sample/rec_video.py +++ b/examples/opensora_pku/opensora/sample/rec_video.py @@ -122,8 +122,9 @@ def main(args): state_dict = ms.load_checkpoint(args.ms_checkpoint) state_dict = dict( - [k.replace("network.", "") if k.startswith("network.") else k, v] for k, v in state_dict.items() + [k.replace("autoencoder.", "") if k.startswith("autoencoder.") else k, v] for k, v in state_dict.items() ) + state_dict = dict([k.replace("_backbone.", "") if "_backbone." in k else k, v] for k, v in state_dict.items()) else: state_dict = None kwarg = { diff --git a/examples/opensora_pku/opensora/utils/sample_utils.py b/examples/opensora_pku/opensora/utils/sample_utils.py index 14ba261a9a..9b5969eea9 100644 --- a/examples/opensora_pku/opensora/utils/sample_utils.py +++ b/examples/opensora_pku/opensora/utils/sample_utils.py @@ -163,6 +163,7 @@ def prepare_pipeline(args): state_dict = dict( [k.replace("network.", "") if k.startswith("network.") else k, v] for k, v in state_dict.items() ) + state_dict = dict([k.replace("_backbone.", "") if "_backbone." in k else k, v] for k, v in state_dict.items()) else: state_dict = None model_version = args.model_path.split("/")[-1] From c08d314e221553e20f0fd14a9e2427dfec5d0567 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 3 Dec 2024 17:37:39 +0800 Subject: [PATCH 089/133] fix import error --- examples/opensora_pku/opensora/train/train_causalvae.py | 2 +- examples/opensora_pku/opensora/train/train_t2v_diffusers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/opensora_pku/opensora/train/train_causalvae.py b/examples/opensora_pku/opensora/train/train_causalvae.py index ac08250e02..471b088f98 100644 --- a/examples/opensora_pku/opensora/train/train_causalvae.py +++ b/examples/opensora_pku/opensora/train/train_causalvae.py @@ -7,8 +7,8 @@ import shutil import sys import time +from copy import deepcopy -import deepcopy import yaml import mindspore as ms diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index e59e7c9700..ac5dea49f2 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -2,8 +2,8 @@ import math import os import sys +from copy import deepcopy -import deepcopy import yaml import mindspore as ms From 47da9b45442ca8abd8227058cab796767e64e6e3 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 3 Dec 2024 17:40:54 +0800 Subject: [PATCH 090/133] fix import error --- examples/opensora_pku/opensora/train/train_causalvae.py | 2 +- examples/opensora_pku/opensora/train/train_t2v_diffusers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/opensora_pku/opensora/train/train_causalvae.py b/examples/opensora_pku/opensora/train/train_causalvae.py index 471b088f98..bf8762ebd1 100644 --- a/examples/opensora_pku/opensora/train/train_causalvae.py +++ b/examples/opensora_pku/opensora/train/train_causalvae.py @@ -92,7 +92,7 @@ def main(args): logger.warning(f"Model will be initialized from config file {args.model_config}.") ae = model_cls.from_config(args.model_config, dtype=dtype, use_recompute=args.use_recompute) json_name = os.path.join(args.output_dir, "config.json") - config = deepcopy.copy(ae.config) + config = deepcopy(ae.config) if hasattr(config, "recompute"): del config.recompute save_diffusers_json(config, json_name) diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index ac5dea49f2..4c64eba5a5 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -163,7 +163,7 @@ def main(args): FA_dtype=FA_dtype, ) json_name = os.path.join(args.output_dir, "config.json") - config = deepcopy.copy(model.config) + config = deepcopy(model.config) if hasattr(config, "recompute"): del config.recompute save_diffusers_json(config, json_name) From b337dc72857148fca1cadcbb12a812333ce265ba Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 3 Dec 2024 17:49:54 +0800 Subject: [PATCH 091/133] remove vae use_recompute --- examples/opensora_pku/opensora/train/train_causalvae.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/opensora_pku/opensora/train/train_causalvae.py b/examples/opensora_pku/opensora/train/train_causalvae.py index bf8762ebd1..d31a1bfddc 100644 --- a/examples/opensora_pku/opensora/train/train_causalvae.py +++ b/examples/opensora_pku/opensora/train/train_causalvae.py @@ -93,8 +93,8 @@ def main(args): ae = model_cls.from_config(args.model_config, dtype=dtype, use_recompute=args.use_recompute) json_name = os.path.join(args.output_dir, "config.json") config = deepcopy(ae.config) - if hasattr(config, "recompute"): - del config.recompute + if hasattr(config, "use_recompute"): + del config.use_recompute save_diffusers_json(config, json_name) if args.load_from_checkpoint is not None: ae.init_from_ckpt(args.load_from_checkpoint) From a6e9a46111a563ccab7a31a413ebb7e0581005d3 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 4 Dec 2024 16:16:57 +0800 Subject: [PATCH 092/133] allow sparse1d is False --- .../diffusion/opensora/modeling_opensora.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index 31edf2913a..3f607cf7a3 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -135,6 +135,8 @@ def _init_patched_inputs(self): # save as attributes used in construct self.patch_size_t = self.config.patch_size_t self.patch_size = self.config.patch_size + self.sparse1d = self.config.sparse1d + self.sparse_n = self.config.sparse_n def recompute(self, b): if not b._has_config_recompute: @@ -400,7 +402,6 @@ def construct( encoder_hidden_states = encoder_hidden_states.swapaxes(0, 1).contiguous() timestep = timestep.view(batch_size, 6, -1).swapaxes(0, 1).contiguous() - sparse_mask = {} # if npu_config is None: # if get_sequence_parallel_state(): # head_num = self.config.num_attention_heads // hccl_info.world_size @@ -408,19 +409,22 @@ def construct( # head_num = self.config.num_attention_heads # else: head_num = None - for sparse_n in [1, 4]: - sparse_mask[sparse_n] = Attention.prepare_sparse_mask( - attention_mask, encoder_attention_mask, sparse_n, head_num - ) + sparse_mask = {} + if self.sparse1d: + for sparse_n in [1, 4]: + sparse_mask[sparse_n] = Attention.prepare_sparse_mask( + attention_mask, encoder_attention_mask, sparse_n, head_num + ) # 2. Blocks for i, block in enumerate(self.transformer_blocks): - if i > 1 and i < 30: - attention_mask, encoder_attention_mask = sparse_mask[block.attn1.processor.sparse_n][ - block.attn1.processor.sparse_group - ] - else: - attention_mask, encoder_attention_mask = sparse_mask[1][block.attn1.processor.sparse_group] + if self.sparse1d: + if i > 1 and i < 30: + attention_mask, encoder_attention_mask = sparse_mask[block.attn1.processor.sparse_n][ + block.attn1.processor.sparse_group + ] + else: + attention_mask, encoder_attention_mask = sparse_mask[1][block.attn1.processor.sparse_group] hidden_states = block( hidden_states, From 5ad9b457da5a82910ceb30b549f927d0435eaf32 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 4 Dec 2024 16:46:06 +0800 Subject: [PATCH 093/133] allow wavelet loss --- .../model/vae/modeling_wfvae.py | 30 +++++++++++++------ .../opensora/train/train_causalvae.py | 5 ++-- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py index bfa2ddbf6d..4cba8a06d9 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py @@ -519,7 +519,7 @@ def build_chunk_start_end(self, t, decoder_mode=False): return start_end def encode(self, x, sample_posterior=True): - posterior_mean, posterior_logvar = self._encode(x) + posterior_mean, posterior_logvar, _ = self._encode(x) if sample_posterior: z = self.sample(posterior_mean, posterior_logvar) else: @@ -533,13 +533,13 @@ def _encode(self, x): if self.use_tiling: h = self.tile_encode(x) - # l1, l2 = None, None + w_coeffs = None else: - h, _ = self.encoder(x) + h, w_coeffs = self.encoder(x) if self.use_quant_layer: h = self.quant_conv(h) posterior_mean, posterior_logvar = mint.split(h, [h.shape[1] // 2, h.shape[1] // 2], dim=1) - return posterior_mean, posterior_logvar + return posterior_mean, posterior_logvar, w_coeffs def tile_encode(self, x): b, c, t, h, w = x.shape @@ -557,6 +557,11 @@ def tile_encode(self, x): return mint.cat(result, dim=2) def decode(self, z): + dec, _ = self._decode(z) + + return dec + + def _decode(self, z): self._empty_causal_cached(self.decoder) self._set_first_chunk(True) @@ -566,9 +571,9 @@ def decode(self, z): else: if self.use_quant_layer: z = self.post_quant_conv(z) - dec, _ = self.decoder(z) + dec, w_coeffs = self.decoder(z) - return dec + return dec, w_coeffs def tile_decode(self, x): b, c, t, h, w = x.shape @@ -606,15 +611,22 @@ def sample(self, mean, logvar): def construct(self, input, sample_posterior=True): # overall pass, mostly for training - posterior_mean, posterior_logvar = self._encode(input) + posterior_mean, posterior_logvar, encoder_w_coeffs = self._encode(input) if sample_posterior: z = self.sample(posterior_mean, posterior_logvar) else: z = posterior_mean - recons = self.decode(z) + recons, decoder_w_coeffs = self._decode(z) + if encoder_w_coeffs is not None and decoder_w_coeffs is not None: + assert len(encoder_w_coeffs) == 2 and len(decoder_w_coeffs) == 2 + e_l1, e_l2 = encoder_w_coeffs + d_l1, d_l2 = decoder_w_coeffs + w_coeffs = [e_l1, d_l1, e_l2, d_l2] + else: + w_coeffs = None - return recons, posterior_mean, posterior_logvar + return recons, posterior_mean, posterior_logvar, w_coeffs def get_last_layer(self): if hasattr(self.decoder.conv_out, "conv"): diff --git a/examples/opensora_pku/opensora/train/train_causalvae.py b/examples/opensora_pku/opensora/train/train_causalvae.py index d31a1bfddc..5f333368ad 100644 --- a/examples/opensora_pku/opensora/train/train_causalvae.py +++ b/examples/opensora_pku/opensora/train/train_causalvae.py @@ -118,8 +118,9 @@ def main(args): # 3. build net with loss (core) # G with loss - if args.wavelet_loss: - logger.warning("wavelet_loss is not implemented, and will be ignored.") + if args.wavelet_weight != 0 and ae.use_tiling: + logger.warning("Wavelet loss and use_tiling cannot be enabled in the same time! wavelet_weight is set to zero.") + args.wavelet_weight = 0.0 ae_with_loss = GeneratorWithLoss( ae, discriminator=disc, From 2ac4c0d8bfcb47329a80830b90a8cf11463d3ca3 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 4 Dec 2024 16:51:25 +0800 Subject: [PATCH 094/133] print wavelet loss --- .../models/causalvideovae/model/losses/net_with_loss.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py index 915de1886e..39ac568d1b 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py @@ -149,6 +149,8 @@ def loss_function( wl_loss_l2 = mint.sum(l1(wavelet_coeffs[0], wavelet_coeffs[1])) / bs wl_loss_l3 = mint.sum(l1(wavelet_coeffs[2], wavelet_coeffs[3])) / bs wl_loss = wl_loss_l2 + wl_loss_l3 + if self.print_losses: + print(f"wl_loss {wl_loss.asnumpy()}") else: wl_loss = 0 loss = mean_weighted_nll_loss + self.kl_weight * kl_loss + self.wavelet_weight * wl_loss From c982edb484fb773b1ba96dd8940cc1501813705b Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 4 Dec 2024 17:16:16 +0800 Subject: [PATCH 095/133] fix disc no_grad --- .../model/losses/net_with_loss.py | 5 ++++- .../diffusion/opensora/net_with_loss.py | 22 ++----------------- .../opensora_pku/opensora/utils/ms_utils.py | 20 +++++++++++++++++ 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py index 39ac568d1b..83dc5ba81a 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/net_with_loss.py @@ -1,3 +1,5 @@ +from opensora.utils.ms_utils import no_grad + import mindspore as ms from mindspore import mint, nn, ops @@ -271,7 +273,8 @@ def construct(self, x: ms.Tensor, global_step=-1, cond=None): """ # 1. AE forward, get posterior (mean, logvar) and recons - recons, mean, logvar = ops.stop_gradient(self.autoencoder(x)) + with no_grad(): + recons = ops.stop_gradient(self.autoencoder(x)[0]) if x.ndim >= 5 and not self.use_3d_disc: # use 2D discriminator diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py b/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py index cc4665bff0..9c5c47c033 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py @@ -2,9 +2,10 @@ from opensora.acceleration.communications import prepare_parallel_data from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info +from opensora.utils.ms_utils import no_grad import mindspore as ms -from mindspore import _no_grad, mint, nn, ops +from mindspore import mint, nn, ops from mindone.diffusers.training_utils import compute_snr @@ -13,25 +14,6 @@ logger = logging.getLogger(__name__) -@ms.jit_class -class no_grad(_no_grad): - """ - A context manager that suppresses gradient memory allocation in PyNative mode. - """ - - def __init__(self): - super().__init__() - self._pynative = ms.get_context("mode") == ms.PYNATIVE_MODE - - def __enter__(self): - if self._pynative: - super().__enter__() - - def __exit__(self, *args): - if self._pynative: - super().__exit__(*args) - - class DiffusionWithLoss(nn.Cell): """An training pipeline for diffusion model diff --git a/examples/opensora_pku/opensora/utils/ms_utils.py b/examples/opensora_pku/opensora/utils/ms_utils.py index 52efa2ebb0..6c8e9659c2 100644 --- a/examples/opensora_pku/opensora/utils/ms_utils.py +++ b/examples/opensora_pku/opensora/utils/ms_utils.py @@ -5,6 +5,7 @@ from opensora.acceleration.parallel_states import initialize_sequence_parallel_state import mindspore as ms +from mindspore import _no_grad from mindspore.communication.management import get_group_size, get_rank, init from mindone.utils.seed import set_random_seed @@ -170,3 +171,22 @@ def init_env( ) initialize_sequence_parallel_state(sp_size) return rank_id, device_num + + +@ms.jit_class +class no_grad(_no_grad): + """ + A context manager that suppresses gradient memory allocation in PyNative mode. + """ + + def __init__(self): + super().__init__() + self._pynative = ms.get_context("mode") == ms.PYNATIVE_MODE + + def __enter__(self): + if self._pynative: + super().__enter__() + + def __exit__(self, *args): + if self._pynative: + super().__exit__(*args) From d37ed03eee43a8fd9b044cb9e2595b8b59af9b52 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 5 Dec 2024 20:04:14 +0800 Subject: [PATCH 096/133] fix vae tile decode error --- .../opensora/models/causalvideovae/model/vae/modeling_wfvae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py index 4cba8a06d9..c43b878dee 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py @@ -567,7 +567,7 @@ def _decode(self, z): if self.use_tiling: dec = self.tile_decode(z) - # l1, l2 = None, None + w_coeffs = None else: if self.use_quant_layer: z = self.post_quant_conv(z) From 61468fa9fcd22fabbbb9f02e947153b1ce67e6b0 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 5 Dec 2024 20:10:12 +0800 Subject: [PATCH 097/133] edit attention_mask shape --- .../models/diffusion/opensora/modules.py | 5 ++++- examples/opensora_pku/opensora/npu_config.py | 19 ++++++++++++++++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py index ebb0274f48..7d9bf6dc3c 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py @@ -99,7 +99,6 @@ def prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n, head_n True: (attention_mask_sparse_1d_group, encoder_attention_mask_sparse_1d_group), } - # NO USE YET def prepare_attention_mask( self, attention_mask: ms.Tensor, target_length: int, batch_size: int, out_dim: int = 3 ) -> ms.Tensor: @@ -279,6 +278,10 @@ def __call__( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) # BSH + # attention_mask shape + if attention_mask.ndim == 3: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4) + # print(f"hidden_states.shape {hidden_states.shape}") #BSH query = attn.to_q(hidden_states) diff --git a/examples/opensora_pku/opensora/npu_config.py b/examples/opensora_pku/opensora/npu_config.py index 0f540c9308..7690c7c8ac 100644 --- a/examples/opensora_pku/opensora/npu_config.py +++ b/examples/opensora_pku/opensora/npu_config.py @@ -261,6 +261,9 @@ def ms_flash_attention( if attention_mask is not None: # flip mask, since ms FA treats 1 as discard, 0 as retain. attention_mask = ~attention_mask if attention_mask.dtype == ms.bool_ else 1 - attention_mask + assert ( + attention_mask.ndim == 4 + ), f"Expect attention mask has four dimensions, but got {attention_mask.shape}" # (b, 1, 1, k_n) - > (b, 1, q_n, k_n), manual broadcast if attention_mask.shape[-2] == 1: attention_mask = mint.tile(attention_mask.bool(), (1, 1, query_tokens, 1)) @@ -334,10 +337,20 @@ def trans_tensor_shape(x, layout, head_num): attn_bias.masked_fill(~temp_mask, npu_config.inf_float) attn_bias.to(query.dtype) - if attention_mask is not None: + elif attention_mask is not None: + # check attention_mask shape (bs, head_num, query_len, token_length) assert ( - not self.enable_FA - ) and attention_mask.dtype != ms.bool, "attention_mask must not be bool type when use this function" + attention_mask.ndim == 4 + ), f"Expect attention mask has four dimensions, but got {attention_mask.shape}" + if attention_mask.shape[1] == 1: + attention_mask = attention_mask.repeat_interleave(head_num, 1) + else: + assert ( + attention_mask.shape[1] == head_num + ), f"Expect attention_mask to be like (bs, 1, query_len, key_len), but got {attention_mask.shape}" + # fill in with -inf + attn_bias = mint.zeros(attention_mask.shape).masked_fill(attention_mask.to(ms.bool_), npu_config.inf_float) + attn_bias.to(query.dtype) attn_weight += attn_bias attn_weight = mint.nn.functional.softmax(attn_weight, dim=-1) From 3f325d9f8a3e37c5e356f41aa376d1ca2d7646af Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 6 Dec 2024 15:31:06 +0800 Subject: [PATCH 098/133] allow set sparse_n other than 4 --- .../opensora/models/diffusion/opensora/modeling_opensora.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index 3f607cf7a3..573d8e7106 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -63,6 +63,7 @@ def __init__( self.gradient_checkpointing = use_recompute # NEW self.use_recompute = use_recompute # NEW self.FA_dtype = FA_dtype # NEW + self.sparse_n = sparse_n # NEW self._init_patched_inputs() if self.use_recompute: @@ -411,7 +412,7 @@ def construct( head_num = None sparse_mask = {} if self.sparse1d: - for sparse_n in [1, 4]: + for sparse_n in [1, self.sparse_n]: sparse_mask[sparse_n] = Attention.prepare_sparse_mask( attention_mask, encoder_attention_mask, sparse_n, head_num ) From 3f170174f778d0da6e0d9c0bd798100ffd02cb94 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 11 Dec 2024 19:00:27 +0800 Subject: [PATCH 099/133] remove redudant args & allow rectified flow in training --- .../diffusion/opensora/net_with_loss.py | 215 +++++++++++++----- .../opensora/train/train_t2v_diffusers.py | 93 +------- examples/opensora_pku/opensora/utils/utils.py | 38 ++++ .../multi-devices/train_debug.sh | 1 - .../multi-devices/train_t2i_stage1.sh | 1 - .../multi-devices/train_t2v_stage2.sh | 1 - .../multi-devices/train_t2v_stage3.sh | 1 - .../single-device/train_t2i_stage1.sh | 1 - .../single-device/train_t2v_stage2.sh | 1 - .../single-device/train_t2v_stage3.sh | 1 - 10 files changed, 204 insertions(+), 149 deletions(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py b/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py index 9c5c47c033..191ee64565 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py @@ -1,8 +1,10 @@ import logging +import math from opensora.acceleration.communications import prepare_parallel_data from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info from opensora.utils.ms_utils import no_grad +from opensora.utils.utils import explicit_uniform_sampling, get_sigmas import mindspore as ms from mindspore import mint, nn, ops @@ -14,6 +16,46 @@ logger = logging.getLogger(__name__) +def compute_density_for_timestep_sampling( + weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None +): + """ + Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = mint.normal(mean=logit_mean, std=logit_std, size=(batch_size,)) + u = mint.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = mint.rand(size=(batch_size,)) + u = 1 - u - mode_scale * (mint.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = mint.rand(size=(batch_size,)) + return u + + +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): + """ + Computes loss weighting scheme for SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "sigma_sqrt": + weighting = (sigmas**-2.0).float() + elif weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = mint.ones_like(sigmas) + return weighting + + class DiffusionWithLoss(nn.Cell): """An training pipeline for diffusion model @@ -24,6 +66,7 @@ class DiffusionWithLoss(nn.Cell): text_encoder / text_encoder_2 (nn.Cell): A text encoding model which accepts token ids and returns text embeddings in shape (T, D). T is the number of tokens, and D is the embedding dimension. train_with_embed (bool): whether to train with embeddings (no need vae and text encoder to extract latent features and text embeddings) + rf_scheduler (bool): whether to apply rectified flow scheduler. """ def __init__( @@ -38,7 +81,14 @@ def __init__( use_image_num: int = 0, dtype=ms.float32, noise_offset: float = 0.0, + rf_scheduler: bool = False, snr_gamma=None, + rank_id: int = 0, + device_num: int = 1, + logit_mean: float = 0.0, + logit_std: float = 1.0, + weighting_scheme: str = "logit_normal", + mode_scale: float = 1.29, ): super().__init__() # TODO: is set_grad() necessary? @@ -48,7 +98,14 @@ def __init__( self.prediction_type = self.noise_scheduler.config.prediction_type self.num_train_timesteps = self.noise_scheduler.config.num_train_timesteps self.noise_offset = noise_offset + self.rf_scheduler = rf_scheduler + self.rank_id = rank_id + self.device_num = device_num self.snr_gamma = snr_gamma + self.logit_mean = logit_mean + self.logit_std = logit_std + self.weighting_scheme = weighting_scheme + self.mode_scale = mode_scale self.text_encoder = text_encoder self.dtype = dtype @@ -166,9 +223,10 @@ def compute_loss(self, x, attention_mask, text_embed, encoder_attention_mask): use_image_num = self.use_image_num noise = ops.randn_like(x) bsz = x.shape[0] - if self.noise_offset: - # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += self.noise_offset * ops.randn((bsz, x.shape[1], 1, 1, 1), dtype=x.dtype) + if not self.rf_scheduler: + if self.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += self.noise_offset * ops.randn((bsz, x.shape[1], 1, 1, 1), dtype=x.dtype) current_step_frame = x.shape[2] if get_sequence_parallel_state() and current_step_frame > 1: x = self.all_gather(x[None])[0] @@ -180,11 +238,35 @@ def compute_loss(self, x, attention_mask, text_embed, encoder_attention_mask): encoder_attention_mask, use_image_num, ) = prepare_parallel_data(x, noise, text_embed, attention_mask, encoder_attention_mask, use_image_num) - - t = ops.randint(0, self.num_train_timesteps, (x.shape[0],), dtype=ms.int32) - if get_sequence_parallel_state(): - t = self.reduce_t(t) % self.num_train_timesteps - x_t = self.noise_scheduler.add_noise(x, noise, t) + if not self.rf_scheduler: + # sample a random timestep for each image without bias + t = explicit_uniform_sampling( + T=self.num_train_timesteps, + n=self.device_num, + rank=self.rank_id, + bsz=bsz, + ) + # t = ops.randint(0, self.num_train_timesteps, (x.shape[0],), dtype=ms.int32) + if get_sequence_parallel_state(): + t = self.reduce_t(t) % self.num_train_timesteps + x_t = self.noise_scheduler.add_noise(x, noise, t) + else: + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=self.weighting_scheme, + batch_size=bsz, + logit_mean=self.logit_mean, + logit_std=self.logit_std, + mode_scale=self.mode_scale, + ) + indices = u * self.num_train_timesteps + t = self.noise_scheduler.timesteps[indices] + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(self.noise_scheduler, t, n_dim=x.ndim, dtype=x.dtype) + x_t = (1.0 - sigmas) * x + sigmas * noise # latte forward input match # text embed: (b n_tokens d) -> (b 1 n_tokens d) @@ -197,58 +279,81 @@ def compute_loss(self, x, attention_mask, text_embed, encoder_attention_mask): encoder_attention_mask=encoder_attention_mask, use_image_num=use_image_num, ) - - if self.prediction_type == "epsilon": - target = noise - elif self.prediction_type == "v_prediction": - target = self.noise_scheduler.get_velocity(x, noise, t) - elif self.prediction_type == "sample": - # We set the target to latents here, but the model_pred will return the noise sample prediction. - target = x - # We will have to subtract the noise residual from the prediction to get the target sample. - model_pred = model_pred - noise - else: - raise ValueError(f"Unknown prediction type {self.prediction_type}") - # comment it to avoid graph syntax error - # if attention_mask is not None and (attention_mask.bool()).all(): - # attention_mask = None - if get_sequence_parallel_state(): - assert (attention_mask.bool()).all() - # assert attention_mask is None - attention_mask = None - # (b c t h w), - bsz, c, _, _, _ = model_pred.shape - if attention_mask is not None: - attention_mask = attention_mask.unsqueeze(1).float().repeat(c, axis=1) # b t h w -> b c t h w - attention_mask = attention_mask.reshape(bsz, -1) - - if self.snr_gamma is None: - # model_pred: b c t h w, attention_mask: b t h w - loss = ops.mse_loss(model_pred.float(), target.float(), reduction="none") - loss = loss.reshape(bsz, -1) + if not self.rf_scheduler: + if self.prediction_type == "epsilon": + target = noise + elif self.prediction_type == "v_prediction": + target = self.noise_scheduler.get_velocity(x, noise, t) + elif self.prediction_type == "sample": + # We set the target to latents here, but the model_pred will return the noise sample prediction. + target = x + # We will have to subtract the noise residual from the prediction to get the target sample. + model_pred = model_pred - noise + else: + raise ValueError(f"Unknown prediction type {self.prediction_type}") + # comment it to avoid graph syntax error + # if attention_mask is not None and (attention_mask.bool()).all(): + # attention_mask = None + if get_sequence_parallel_state(): + assert (attention_mask.bool()).all() + # assert attention_mask is None + attention_mask = None + # (b c t h w), + bsz, c, _, _, _ = model_pred.shape if attention_mask is not None: - loss = (loss * attention_mask).sum() / attention_mask.sum() # mean loss on unpad patches + attention_mask = attention_mask.unsqueeze(1).float().repeat(c, axis=1) # b t h w -> b c t h w + attention_mask = attention_mask.reshape(bsz, -1) + + if self.snr_gamma is None: + # model_pred: b c t h w, attention_mask: b t h w + loss = ops.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.reshape(bsz, -1) + if attention_mask is not None: + loss = (loss * attention_mask).sum() / attention_mask.sum() # mean loss on unpad patches + else: + loss = loss.mean() else: - loss = loss.mean() + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(self.noise_scheduler, t) + mse_loss_weights = ops.stack([snr, self.snr_gamma * ops.ones_like(t)], axis=1).min(axis=1)[0] + if self.prediction_type == "epsilon": + mse_loss_weights = mse_loss_weights / snr + elif self.prediction_type == "v_prediction": + mse_loss_weights = mse_loss_weights / (snr + 1) + loss = ops.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.reshape(bsz, -1) + mse_loss_weights = mse_loss_weights.reshape(bsz, 1) + if attention_mask is not None: + loss = ( + loss * attention_mask * mse_loss_weights + ).sum() / attention_mask.sum() # mean loss on unpad patches + else: + loss = (loss * mse_loss_weights).mean() else: - # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. - # Since we predict the noise instead of x_0, the original formulation is slightly changed. - # This is discussed in Section 4.2 of the same paper. - snr = compute_snr(self.noise_scheduler, t) - mse_loss_weights = ops.stack([snr, self.snr_gamma * ops.ones_like(t)], axis=1).min(axis=1)[0] - if self.prediction_type == "epsilon": - mse_loss_weights = mse_loss_weights / snr - elif self.prediction_type == "v_prediction": - mse_loss_weights = mse_loss_weights / (snr + 1) - loss = ops.mse_loss(model_pred.float(), target.float(), reduction="none") - loss = loss.reshape(bsz, -1) - mse_loss_weights = mse_loss_weights.reshape(bsz, 1) + if mint.all(attention_mask.bool()): + attention_mask = None + + b, c, _, _, _ = model_pred.shape if attention_mask is not None: - loss = ( - loss * attention_mask * mse_loss_weights - ).sum() / attention_mask.sum() # mean loss on unpad patches + attention_mask = attention_mask.unsqueeze(1).float().repeat(c, axis=1) # b t h w -> b c t h w + attention_mask = attention_mask.reshape(b, -1) + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=self.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - x + + # Compute regular loss. + loss_mse = (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1) + if attention_mask is not None: + loss = (loss_mse * attention_mask).sum() / attention_mask.sum() else: - loss = (loss * mse_loss_weights).mean() + loss = loss_mse.mean() + return loss diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index 4c64eba5a5..a11fa9e6e1 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -98,6 +98,9 @@ def main(args): vae = None else: print_banner("vae init") + if args.vae_fp32: + logger.info("Force VAE running in FP32") + args.vae_precision = "fp32" vae_dtype = get_precision(args.vae_precision) kwarg = { "state_dict": None, @@ -256,15 +259,9 @@ def main(args): noise_scheduler = DDPMScheduler(**kwargs) elif args.rf_scheduler: noise_scheduler = FlowMatchEulerDiscreteScheduler() - # noise_scheduler_copy = copy.deepcopy(noise_scheduler) else: noise_scheduler = DDPMScheduler(**kwargs) - # Get the target for loss depending on the prediction type - if args.prediction_type is not None: - # set prediction_type of scheduler if defined - noise_scheduler.register_to_config(prediction_type=args.prediction_type) - assert args.use_image_num >= 0, f"Expect to have use_image_num>=0, but got {args.use_image_num}" if args.use_image_num > 0: logger.info("Enable video-image-joint training") @@ -285,6 +282,9 @@ def main(args): dtype=model_dtype, noise_offset=args.noise_offset, snr_gamma=args.snr_gamma, + rf_scheduler=args.rf_scheduler, + rank_id=rank_id, + device_num=device_num, ) latent_diffusion_eval, metrics, eval_indexes = None, None, None @@ -805,80 +805,6 @@ def parse_t2v_train_args(parser): parser.add_argument("--enable_profiling", action="store_true") parser.add_argument("--num_sampling_steps", type=int, default=20) parser.add_argument("--guidance_scale", type=float, default=4.5) - parser.add_argument( - "--checkpoints_total_limit", type=int, default=None, help=("Max number of checkpoints to store.") - ) - - # optimizer & scheduler - parser.add_argument( - "--optimizer", type=str, default="adamW", help='The optimizer type to use. Choose between ["AdamW", "prodigy"]' - ) - parser.add_argument( - "--learning_rate", - type=float, - default=1e-4, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument( - "--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", - ) - parser.add_argument( - "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." - ) - parser.add_argument( - "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." - ) - parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") - parser.add_argument("--adam_weight_decay", type=float, default=1e-02, help="Weight decay to use for unet params") - parser.add_argument( - "--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder" - ) - parser.add_argument( - "--adam_epsilon", type=float, default=1e-15, help="Epsilon value for the Adam optimizer and Prodigy optimizers." - ) - parser.add_argument( - "--prodigy_use_bias_correction", - type=bool, - default=True, - help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", - ) - parser.add_argument( - "--prodigy_safeguard_warmup", - type=bool, - default=True, - help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. Ignored if optimizer is adamW", - ) - parser.add_argument( - "--prodigy_beta3", - type=float, - default=None, - help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " - "uses the value of square root of beta2. Ignored if optimizer is adamW", - ) - parser.add_argument( - "--allow_tf32", - action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), - ) - parser.add_argument( - "--mixed_precision", - type=str, - default=None, - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" - " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." - ), - ) - - parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") - ######################## parser.add_argument("--output_dir", default="outputs/", help="The directory where training results are saved.") parser.add_argument("--dataset", type=str, default="t2v") @@ -927,12 +853,6 @@ def parse_t2v_train_args(parser): parser.add_argument("--use_rope", action="store_true") parser.add_argument("--pretrained", type=str, default=None) - parser.add_argument( - "--enable_stable_fp32", - default=True, - type=str2bool, - help="Whether to some cells, e.g., LayerNorm, silu, into fp32", - ) parser.add_argument("--tile_overlap_factor", type=float, default=0.25) parser.add_argument("--enable_tiling", action="store_true") @@ -941,7 +861,6 @@ def parse_t2v_train_args(parser): parser.add_argument("--model_max_length", type=int, default=512) parser.add_argument("--multi_scale", action="store_true") - # parser.add_argument("--enable_tracker", action="store_true") parser.add_argument("--use_image_num", type=int, default=0) parser.add_argument("--use_img_from_vid", action="store_true") parser.add_argument( diff --git a/examples/opensora_pku/opensora/utils/utils.py b/examples/opensora_pku/opensora/utils/utils.py index a6890c975b..7e871cf8f2 100644 --- a/examples/opensora_pku/opensora/utils/utils.py +++ b/examples/opensora_pku/opensora/utils/utils.py @@ -3,6 +3,7 @@ import html import json import logging +import random import re import urllib.parse as ul from multiprocessing import Pool @@ -48,6 +49,43 @@ def to_2tuple(x): return (x, x) +def explicit_uniform_sampling(T, n, rank, bsz): + """ + Explicit Uniform Sampling with integer timesteps and MindSpore. + + Args: + T (int): Maximum timestep value. + n (int): Number of ranks (data parallel processes). + rank (int): The rank of the current process (from 0 to n-1). + bsz (int): Batch size, number of timesteps to return. + + Returns: + ms.Tensor: A tensor of shape (bsz,) containing uniformly sampled integer timesteps + within the rank's interval. + """ + interval_size = T / n # Integer division to ensure boundaries are integers + lower_bound = interval_size * rank - 0.5 + upper_bound = interval_size * (rank + 1) - 0.5 + sampled_timesteps = [round(random.uniform(lower_bound, upper_bound)) for _ in range(bsz)] + + # Uniformly sample within the rank's interval, returning integers + sampled_timesteps = ms.Tensor([round(random.uniform(lower_bound, upper_bound)) for _ in range(bsz)], dtype=ms.int32) + # sampled_timesteps = sampled_timesteps.long() + return sampled_timesteps + + +def get_sigmas(noise_scheduler, timesteps, n_dim=4, dtype=ms.float32): + sigmas = noise_scheduler.sigmas.to(dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps + + step_indices = [(schedule_timesteps == t).nonzero() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + def get_experiment_dir(root_dir, args): # if args.pretrained is not None and 'Latte-XL-2-256x256.pt' not in args.pretrained: # root_dir += '-WOPRE' diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh index 4f7e89e229..d79fb40d14 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh @@ -41,7 +41,6 @@ msrun --bind_core=True --worker_num=4 --local_worker_num=4 --master_port=6000 -- --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ - --enable_stable_fp32 True\ --ema_decay 0.999 \ --speed_factor 1.0 \ --drop_short_ratio 1.0 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh index 7e93b65445..71712662c1 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh @@ -43,7 +43,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ - --enable_stable_fp32 True\ --ema_decay 0.999 \ --speed_factor 1.0 \ --drop_short_ratio 0.0 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh index 68d0399aec..7ac9fc44b5 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh @@ -44,7 +44,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ - --enable_stable_fp32 True\ --ema_decay 0.999 \ --speed_factor 1.0 \ --drop_short_ratio 0.0 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh index b8a5178deb..05d544f23c 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh @@ -43,7 +43,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ - --enable_stable_fp32 True\ --ema_decay 0.999 \ --speed_factor 1.0 \ --drop_short_ratio 0.0 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh index 75d55b51f6..9e6bb6b1fb 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh @@ -40,7 +40,6 @@ python opensora/train/train_t2v_diffusers.py \ --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ - --enable_stable_fp32 True\ --ema_decay 0.999 \ --speed_factor 1.0 \ --drop_short_ratio 0.0 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh index 4c577928e5..e380c1704c 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh @@ -40,7 +40,6 @@ python opensora/train/train_t2v_diffusers.py \ --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ - --enable_stable_fp32 True\ --ema_decay 0.999 \ --speed_factor 1.0 \ --drop_short_ratio 0.0 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh index 6a91ddb451..1b7105f3fe 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh @@ -40,7 +40,6 @@ python opensora/train/train_t2v_diffusers.py \ --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ - --enable_stable_fp32 True\ --ema_decay 0.999 \ --speed_factor 1.0 \ --drop_short_ratio 0.0 \ From 93b45136aa60f4e5dface249bb4c486491b7b496 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 12 Dec 2024 10:41:42 +0800 Subject: [PATCH 100/133] allow different length --- examples/opensora_pku/opensora/eval/cal_lpips.py | 5 +++-- examples/opensora_pku/opensora/eval/cal_psnr.py | 5 +++-- examples/opensora_pku/opensora/eval/cal_ssim.py | 5 +++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/opensora_pku/opensora/eval/cal_lpips.py b/examples/opensora_pku/opensora/eval/cal_lpips.py index 31f7e88368..3b752ff8d8 100644 --- a/examples/opensora_pku/opensora/eval/cal_lpips.py +++ b/examples/opensora_pku/opensora/eval/cal_lpips.py @@ -33,7 +33,7 @@ def trans(x): def calculate_lpips(videos1, videos2): # image should be RGB, IMPORTANT: normalized to [-1,1] - assert videos1.shape == videos2.shape + # assert videos1.shape == videos2.shape # videos [batch_size, timestamps, channel, h, w] @@ -53,7 +53,8 @@ def calculate_lpips(videos1, videos2): video2 = ms.Tensor(video2, dtype=ms.float32) lpips_results_of_a_video = [] - for clip_timestamp in range(len(video1)): + length = min(len(video1), len(video2)) + for clip_timestamp in range(length): # get a img # img [timestamps[x], channel, h, w] # img [channel, h, w] tensor diff --git a/examples/opensora_pku/opensora/eval/cal_psnr.py b/examples/opensora_pku/opensora/eval/cal_psnr.py index a12ffa351a..0f451070a0 100644 --- a/examples/opensora_pku/opensora/eval/cal_psnr.py +++ b/examples/opensora_pku/opensora/eval/cal_psnr.py @@ -24,7 +24,7 @@ def trans(x): def calculate_psnr(videos1, videos2): # videos [batch_size, timestamps, channel, h, w] - assert videos1.shape == videos2.shape + # assert videos1.shape == videos2.shape videos1 = trans(videos1) videos2 = trans(videos2) @@ -38,7 +38,8 @@ def calculate_psnr(videos1, videos2): video2 = videos2[video_num] psnr_results_of_a_video = [] - for clip_timestamp in range(len(video1)): + length = min(len(video1), len(video2)) + for clip_timestamp in range(length): # get a img # img [timestamps[x], channel, h, w] # img [channel, h, w] numpy diff --git a/examples/opensora_pku/opensora/eval/cal_ssim.py b/examples/opensora_pku/opensora/eval/cal_ssim.py index 4ad92deff8..2b0c064c3a 100644 --- a/examples/opensora_pku/opensora/eval/cal_ssim.py +++ b/examples/opensora_pku/opensora/eval/cal_ssim.py @@ -49,7 +49,7 @@ def trans(x): def calculate_ssim(videos1, videos2): # videos [batch_size, timestamps, channel, h, w] - assert videos1.shape == videos2.shape + # assert videos1.shape == videos2.shape videos1 = trans(videos1) videos2 = trans(videos2) @@ -63,7 +63,8 @@ def calculate_ssim(videos1, videos2): video2 = videos2[video_num] ssim_results_of_a_video = [] - for clip_timestamp in range(len(video1)): + length = min(len(video1), len(video2)) + for clip_timestamp in range(length): # get a img # img [timestamps[x], channel, h, w] # img [channel, h, w] numpy From ca4da34baae5297d1e4d53f68c5a46f0ad321c4a Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 12 Dec 2024 10:45:08 +0800 Subject: [PATCH 101/133] allow parallel video reconstruction --- .../opensora_pku/scripts/causalvae/rec_video_folder.sh | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh b/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh index 640c54a5be..3ea7b087f5 100644 --- a/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh +++ b/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh @@ -1,10 +1,10 @@ -python examples/rec_video_folder.py \ +msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="parallel_logs/" examples/rec_video_folder.py \ --batch_size 1 \ --real_video_dir datasets/UCF-101/ \ --data_file_path datasets/ucf101_test.csv \ --generated_video_dir recons/ucf101_test/ \ --device Ascend \ - --sample_fps 10 \ + --sample_fps 30 \ --sample_rate 1 \ --num_frames 25 \ --height 256 \ @@ -12,6 +12,7 @@ python examples/rec_video_folder.py \ --num_workers 8 \ --ae "WFVAEModel_D8_4x8x8" \ --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ - --enable_tiling \ --mode 1 \ --jit_syntax_level lax \ + --use_parallel True \ + # --ms_checkpoint path/to/ms/ckpt From 5a183a89185f94bfe6f21056d11ddbdecef14a53 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 12 Dec 2024 11:44:38 +0800 Subject: [PATCH 102/133] print loss weight item and resume from log file --- .../opensora/train/train_causalvae.py | 26 ++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/examples/opensora_pku/opensora/train/train_causalvae.py b/examples/opensora_pku/opensora/train/train_causalvae.py index 5f333368ad..8a3d10be58 100644 --- a/examples/opensora_pku/opensora/train/train_causalvae.py +++ b/examples/opensora_pku/opensora/train/train_causalvae.py @@ -9,6 +9,7 @@ import time from copy import deepcopy +import pandas as pd import yaml import mindspore as ms @@ -121,6 +122,20 @@ def main(args): if args.wavelet_weight != 0 and ae.use_tiling: logger.warning("Wavelet loss and use_tiling cannot be enabled in the same time! wavelet_weight is set to zero.") args.wavelet_weight = 0.0 + headers = [ + "perceptual loss weight", + "KL div loss weight", + "Wavelet Loss weight", + "Discriminator loss weight (start)", + ] + values = [ + "{:.2f}".format(args.perceptual_weight), + "{:.2f}".format(args.kl_weight), + "{:.2f}".format(args.wavelet_weight), + "{:.2f}({:d})".format(args.disc_weight, args.disc_start), + ] + df = pd.DataFrame([values], columns=headers) + print(df) ae_with_loss = GeneratorWithLoss( ae, discriminator=disc, @@ -443,9 +458,14 @@ def main(args): else: if not os.path.exists(f"{args.output_dir}/rank_{rank_id}"): os.makedirs(f"{args.output_dir}/rank_{rank_id}") - loss_log_file = open(f"{args.output_dir}/rank_{rank_id}/result.log", "w") - loss_log_file.write("step\tloss_ae\tloss_disc\ttrain_time(s)\n") - loss_log_file.flush() + if args.resume_from_checkpoint and os.path.exists(f"{args.output_dir}/rank_{rank_id}/result.log"): + # resume the loss log if it exists + loss_log_file = open(f"{args.output_dir}/rank_{rank_id}/result.log", "a") + else: + loss_log_file = open(f"{args.output_dir}/rank_{rank_id}/result.log", "w") + loss_log_file.write("step\tloss_ae\tloss_disc\ttrain_time(s)\n") + loss_log_file.flush() + if rank_id == 0: ckpt_manager = CheckpointManager(ckpt_dir, "latest_k", k=args.ckpt_max_keep) # output_numpy=True ? From 2aa786e5d2a5e68071e6c4f1bea001e7977d0239 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 12 Dec 2024 15:12:54 +0800 Subject: [PATCH 103/133] impr logging --- .../opensora_pku/opensora/train/train_t2v_diffusers.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index a11fa9e6e1..0b2c775745 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -537,6 +537,7 @@ def main(args): loss_scaler.loss_scale_value = loss_scale loss_scaler.cur_iter = cur_iter loss_scaler.last_overflow_iter = last_overflow_iter + logger.info(f"Resume training from {resume_ckpt}") # trainer (standalone and distributed) ema = ( @@ -705,12 +706,8 @@ def main(args): f"Num trainable params: {num_params_trainable}", f"Transformer model dtype: {model_dtype}", f"Transformer AMP level: {args.amp_level}" if not args.global_bf16 else "Global BF16: True", - f"VAE dtype: {vae_dtype} (amp level O2)" - + ( - f"\nText encoder dtype: {text_encoder_dtype} (amp level O2)" - if text_encoder_dtype is not None - else "" - ), + f"VAE dtype: {vae_dtype}" + + (f"\nText encoder dtype: {text_encoder_dtype}" if text_encoder_dtype is not None else ""), f"Learning rate: {learning_rate}", f"Instantaneous batch size per device: {args.train_batch_size}", f"Total train batch size (w. parallel, distributed & accumulation): {total_batch_size}", From e3dc5e6a319916f25d6a280c69dc57946e1aee4b Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 16 Dec 2024 16:04:53 +0800 Subject: [PATCH 104/133] readme update to vae training --- examples/opensora_pku/README.md | 177 +++++++----------- .../causalvideovae/sample/rec_video_vae.py | 2 +- .../opensora/utils/sample_utils.py | 2 +- examples/opensora_pku/requirements.txt | 11 +- .../{ => multi-devices}/rec_video_folder.sh | 3 +- .../train_stage_1.sh} | 7 +- .../causalvae/multi-devices/train_stage_2.sh | 41 ++++ .../causalvae/multi-devices/train_stage_3.sh | 41 ++++ .../scripts/causalvae/release.json | 85 --------- .../causalvae/{ => single-device}/eval.sh | 0 .../{ => single-device}/rec_image.sh | 2 - .../{ => single-device}/rec_video.sh | 6 +- .../single-device/rec_video_folder.sh | 15 ++ .../train.sh} | 5 +- .../multi-devices/sample_t2v_93x640_ddp.sh | 1 - .../multi-devices/sample_t2v_93x640_sp.sh | 2 - .../multi-devices/train_debug.sh | 61 ------ .../multi-devices/train_t2i_stage1.sh | 2 - .../multi-devices/train_t2v_stage2.sh | 2 - .../multi-devices/train_t2v_stage3.sh | 2 - .../single-device/sample_debug.sh | 36 ---- .../single-device/sample_t2i_1x320x320.sh | 1 - .../single-device/sample_t2v_93x640.sh | 1 - .../sample_t2v_93x640_2texenc.sh | 1 - .../single-device/train_t2i_stage1.sh | 2 - .../single-device/train_t2v_stage2.sh | 2 - .../single-device/train_t2v_stage3.sh | 2 - 27 files changed, 174 insertions(+), 338 deletions(-) rename examples/opensora_pku/scripts/causalvae/{ => multi-devices}/rec_video_folder.sh (92%) rename examples/opensora_pku/scripts/causalvae/{train_with_gan_loss_multi_device.sh => multi-devices/train_stage_1.sh} (91%) create mode 100644 examples/opensora_pku/scripts/causalvae/multi-devices/train_stage_2.sh create mode 100644 examples/opensora_pku/scripts/causalvae/multi-devices/train_stage_3.sh delete mode 100644 examples/opensora_pku/scripts/causalvae/release.json rename examples/opensora_pku/scripts/causalvae/{ => single-device}/eval.sh (100%) rename examples/opensora_pku/scripts/causalvae/{ => single-device}/rec_image.sh (83%) rename examples/opensora_pku/scripts/causalvae/{ => single-device}/rec_video.sh (75%) create mode 100644 examples/opensora_pku/scripts/causalvae/single-device/rec_video_folder.sh rename examples/opensora_pku/scripts/causalvae/{train_with_gan_loss.sh => single-device/train.sh} (91%) delete mode 100644 examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh delete mode 100644 examples/opensora_pku/scripts/text_condition/single-device/sample_debug.sh diff --git a/examples/opensora_pku/README.md b/examples/opensora_pku/README.md index 0e65c062d2..1298598e47 100644 --- a/examples/opensora_pku/README.md +++ b/examples/opensora_pku/README.md @@ -8,7 +8,7 @@ Here we provide an efficient MindSpore version of [Open-Sora-Plan](https://githu | Official News from OpenSora-PKU | MindSpore Support | | ------------------ | ---------- | -| **[2024.10.16]** 🎉 PKU released version 1.3.0, featuring: **WFVAE**, **pompt refiner**, **data filtering strategy**, **sparse attention**, and **bucket training strategy**. They also support 93x480p within **24G VRAM**. More details can be found at their latest [report](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.3.0.md). | 📝 Working in Progress | +| **[2024.10.16]** 🎉 PKU released version 1.3.0, featuring: **WFVAE**, **pompt refiner**, **data filtering strategy**, **sparse attention**, and **bucket training strategy**. They also support 93x480p within **24G VRAM**. More details can be found at their latest [report](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.3.0.md). | ✅ V.1.3.0 WFVAE and OpenSoraT2V: inference, multi-stage & multi-devices training | | **[2024.07.24]** 🔥🔥🔥 PKU launched Open-Sora Plan v1.2.0, utilizing a 3D full attention architecture instead of 2+1D. See their latest [report](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.2.0.md). | ✅ V.1.2.0 CausalVAE inference & OpenSoraT2V multi-stage training| | **[2024.05.27]** 🚀🚀🚀 PKU launched Open-Sora Plan v1.1.0, which significantly improves video quality and length, and is fully open source! Please check out their latest [report](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.1.0.md). | ✅ V.1.1.0 CausalVAE inference and LatteT2V infernece & three-stage training (`65x512x512`, `221x512x512`, `513x512x512`) | | **[2024.04.09]** 🚀 PKU shared the latest exploration on metamorphic time-lapse video generation: [MagicTime](https://github.com/PKU-YuanGroup/MagicTime), and the dataset for train (updating): [Open-Sora-Dataset](https://github.com/PKU-YuanGroup/Open-Sora-Dataset).| N.A. | @@ -29,9 +29,9 @@ Here we provide an efficient MindSpore version of [Open-Sora-Plan](https://githu The following videos are generated based on MindSpore and Ascend 910*. -Open-Sora-Plan v1.2.0 Demo +Open-Sora-Plan v1.3.0 Demo -29×1280×720 Text-to-Video Generation. +93×640×640 Text-to-Video Generation. | 29x720x1280 (1.2s) | | --- | @@ -52,12 +52,13 @@ Videos are saved to `.gif` for display. ## 🔆 Features -- 📍 **Open-Sora-Plan v1.2.0** with the following features - - ✅ CausalVAEModel_D4_4x8x8 inference. Supports video reconstruction. +- 📍 **Open-Sora-Plan v1.3.0** with the following features + - ✅ WFVAE inference & multi-stage training. - ✅ mT5-xxl TextEncoder model inference. - - ✅ Text-to-video generation up to 93 frames and 720x1280 resolution. - - ✅ Multi-stage training using Zero2 and Sequence parallelism. - - ✅ Acceleration methods: flash attention, recompute (graident checkpointing), mixed precision, data parallelism, optimizer-parallel, etc.. + - ✅ Prompt Refiner. + - ✅ Text-to-video generation up to 93 frames and 640x640 resolution. + - ✅ Multi-stage training using Zero2 and sequence parallelism. + - ✅ Acceleration methods: flash attention, recompute (graident checkpointing), mixed precision, data parallelism, etc.. - ✅ Evaluation metrics : PSNR and SSIM. @@ -122,7 +123,7 @@ For EulerOS, instructions on ffmpeg and decord installation are as follows. ### Open-Sora-Plan v1.3.0 Model Weights -Please download the torch checkpoint of mT5-xxl from [google/mt5-xxl](https://huggingface.co/google/mt5-xxl/tree/main), and download the opensora v1.2.0 models' weights from [LanguageBind/Open-Sora-Plan-v1.3.0](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main). Place them under `examples/opensora_pku` as shown below: +Please download the torch checkpoint of mT5-xxl from [google/mt5-xxl](https://huggingface.co/google/mt5-xxl/tree/main), and download the opensora v1.3.0 models' weights from [LanguageBind/Open-Sora-Plan-v1.3.0](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main). Place them under `examples/opensora_pku` as shown below: ```bash mindone/examples/opensora_pku ├───LanguageBind @@ -155,7 +156,7 @@ Once the checkpoint files have all been prepared, you can refer to the inference ### CausalVAE Command Line Inference -You can run video-to-video reconstruction task using `scripts/causalvae/rec_video.sh`: +You can run video-to-video reconstruction task using `scripts/causalvae/single-device/rec_video.sh`: ```bash python examples/rec_video.py \ --ae "WFVAEModel_D8_4x8x8" \ @@ -164,72 +165,50 @@ python examples/rec_video.py \ --rec_path rec.mp4 \ --device Ascend \ --sample_rate 1 \ - --num_frames 61 \ + --num_frames 65 \ --height 512 \ --width 512 \ --fps 30 \ - --enable_tiling \ - --mode 1 \ + --enable_tiling ``` Please change the `--video_path` to the existing video file path and `--rec_path` to the reconstructed video file path. You can set `--grid` to save the original video and the reconstructed video in the same output file. -You can also run video reconstruction given an input video folder. See `scripts/causalvae/rec_video_folder.sh`. +You can also run video reconstruction given an input video folder. See `scripts/causalvae/single-device/rec_video_folder.sh`. ### Open-Sora-Plan v1.3.0 Command Line Inference -You need download the models manually. -First, you need to download checkpoint including [diffusion model](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/any93x640x640), [vae](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/vae) and [text encoder](https://huggingface.co/google/mt5-xxl), and optional [second text encoder](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k). The [prompt refiner](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.3.0/tree/main/prompt_refiner) is optional. - - - - -You can run text-to-video inference on a single Ascend device using the script `scripts/text_condition/single-device/sample_t2v_93x640.sh` by modifying `--model_path`, `--text_encoder_name_1` and `--ae_path`. The `--caption_refiner` and `--text_encoder_name_2` are optional. - - - +You can run text-to-video inference on a single Ascend device using the script `scripts/text_condition/single-device/sample_t2v_93x640.sh`. ```bash # Single NPU python opensora/sample/sample.py \ --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ --version v1_3 \ --num_frames 93 \ - --height 352 \ + --height 640 \ --width 640 \ --text_encoder_name_1 google/mt5-xxl \ - --text_prompt examples/prompt_list_0.txt \ + --text_prompt examples/sora.txt \ --ae WFVAEModel_D8_4x8x8 \ --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ - --save_img_path "./sample_videos/prompt_list_0_93x640" \ + --save_img_path "./sample_videos/sora_93x640_mt5" \ --fps 18 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ --enable_tiling \ --max_sequence_length 512 \ --sample_method EulerAncestralDiscrete \ + --seed 1234 \ --num_samples_per_prompt 1 \ --rescale_betas_zero_snr \ --prediction_type "v_prediction" \ - --mode 1 + --precision bf16 \ ``` -You can change the `num_frames`, `height` and `width`. -Note that DiT model is trained arbitrarily on stride=32. +You can change the `num_frames`, `height` and `width`. Note that DiT model is trained arbitrarily on stride=32. So keep the resolution of the inference a multiple of 32. `num_frames` needs to be 4n+1, e.g. 93, 77, 61, 45, 29, 1. - - -**To be revised.** - -If you want to run a multi-device inference, e.g., 8 cards, please use `msrun` and pass `--use_parallel=True` as the example below: - -```bash -# 8 NPUs -msrun --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir="output_log" \ - python opensora/sample/sample_t2v.py \ - --use_parallel True \ - ... # pass other arguments -``` -The command above will run a 8-card inference and save the log files into "output_log". `--master_port` specifies the scheduler binding port number. `--worker_num` and `--local_worker_num` should be the same to the number of running devices, e.g., 8. +If you want to run a multi-device inference using data parallelism, please use `scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh`. +The script will run a 8-card inference and save the log files into "parallel_logs/". `--master_port` specifies the scheduler binding port number. `--worker_num` and `--local_worker_num` should be the same to the number of running devices, e.g., 8. In case of the following error: ```bash @@ -243,45 +222,24 @@ See more examples of multi-device inference scripts under `scripts/text_condifio ### Sequence Parallelism -We support running inference with sequence parallelism. Please see the `sample_t2v_29x480p_sp.sh` and `sample_t2v_29x720p_sp.sh` under `scripts/text_condition/multi-devices/`. - -If you set `--sp_size 8` to run sequence parallelism on 8 NPUs, you should also edit as follows: -```shell -export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +We support running inference with sequence parallelism. Please see the `sample_t2v_93x640_sp.sh` under `scripts/text_condition/multi-devices/`. The script will run a 8-card inference with `sp_size=8`, which means each video tensor is sliced into 8 parts along the sequence dimension. If you want to try `sp_size=4`, you can revise it as below: +```bash +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 +msrun --bind_core=True --worker_num=4 --local_worker_num=4 --master_port=9000 --log_dir="./sample_videos/sora_93x640_mt5_sp/parallel_logs/" \ + opensora/sample/sample.py \ + ... \ + --sp_size 4 ``` + ## Training -### Causal Video VAE +### WFVAE #### Preparation **Step 1: Downloading Datasets**: -To train the causal vae model, you need to prepare a video dataset. Open-Sora-Plan-v1.2.0 trains vae in two stages. In the first stage, the authors trained vae on the Kinetic400 video dataset. Please download K400 dataset from [this repository](https://github.com/cvdfoundation/kinetics-dataset). In the second stage, they trained vae on Open-Sora-Dataset-v1.1.0 dataset. We give a tutorial on how to download the v1.1.0 datasets. See [downloading tutorial](./tools/download/README.md). - -**Step 2: Converting Pretrained Weights**: - -As with v1.1.0, they initialized from the [SD2.1 VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse) using tail initialization for better convergence. Please download the torch weight file from the given [URL](https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main). - -After downloading the [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main) weights, you can run: -```bash -python tools/model_conversion/convert_vae_2d.py --src path/to/diffusion.safetensor --target /path/to/sd-vae-ft-mse.ckpt. -``` -This can convert the torch weight file into mindspore weight file. - -They you can inflate the 2d vae model checkpoint into a 3d causal vae initial weight file as follows: - -```bash -python tools/model_conversion/inflate_vae2d_to_vae3d.py \ - --src /path/to/sd-vae-ft-mse.ckpt \ - --target pretrained/causal_vae_488_init.ckpt -``` - -In order to train vae with lpips loss, please also download [lpips_vgg-426bf45c.ckpt](https://download-mindspore.osinfra.cn/toolkits/mindone/autoencoders/lpips_vgg-426bf45c.ckpt) and put it under `pretrained/`. - -#### Standalone Training - -The first-stage training is conducted on 25-frame 256×256 videos of the [K400](https://github.com/cvdfoundation/kinetics-dataset) dataset. you can revise the `--video_path` in the training script to the video folder path of your downloaded dataset. This will allow the training script to load all video files under `video_path` in a **recursive manner**, and use them as the training data. Make sure the `--load_from_checkpoint` is set to the pretrained weight, e.g., `pretrained/causal_vae_488_init.ckpt`. +To train the causal vae model, you need to prepare a video dataset. Please download K400 dataset from [this repository](https://github.com/cvdfoundation/kinetics-dataset) as used in the [Arxiv paper](https://arxiv.org/abs/2411.17459) or download the UCF101 dataset from [the official website](https://www.crcv.ucf.edu/data/UCF101.php) as used in this tutorial.
@@ -306,62 +264,55 @@ python opensora/train/train_causalvae.py \ Similarly, you can create a csv file to include the test set videos, and pass the csv file to `--data_file_path` in `examples/rec_video_vae.py`.
-To launch a single-card training, please run: -```bash -bash scripts/causalvae/train_with_gan_loss.sh -``` +**Step 2: Prepare Pretrained Weights**: -> Note: -> - Supports resume training by setting `--resume_training_checkpoint True`. It is the same for the multi-device training script. +Open-Sora-Plan-v1.3.0 trains WFVAE in multiple stages. The loss used for the first two stages is a weighted sum of multiple loss terms: -#### Multi-Device Training +$L = L_{recon} + \lambda_{adv}L_{adv} + \lambda_{KL}L_{KL} + \lambda_{WL}L_{WL}$ -For parallel training, please use `msrun` and pass `--use_parallel=True`. -```bash -# 8 NPUs -msrun --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir="output_log" \ - python opensora/train/train_causalvae.py \ - --use_parallel True \ - ... # pass other arguments, please refer to the single-device training script. -``` -For more details, please take `scripts/causalvae/train_with_gan_loss_multi_device.sh` as an example. +$L_{recon}$ represents the reconstruction loss (L1). $L_{adv}$ is the adversarial loss, and its weight $\lambda_{adv}$ is given by the argument `--disc_weight`. $L_{KL}$ is the KL divergence loss, and its weight $\lambda_{KL}$ is given by `--kl_weight`. $L_{WL}$ is the wavelet loss, and its weight $\lambda_{WL}$ is given by `--wavelet_weight`. In the third stage, LPIPS loss is also used to improve the performance. Its weight $\lambda_{lpips}$ is given by the argument `--perceptual_weight `. Please see more arguments in `opensora/train/train_causalvae.py`. + +In order to train vae with LPIPS loss, please also download [lpips_vgg-426bf45c.ckpt](https://download-mindspore.osinfra.cn/toolkits/mindone/autoencoders/lpips_vgg-426bf45c.ckpt) and put it under `pretrained/`. + +**Steps 3: Hyper-parameters Setting** + +Please find the hyper-parameters in each stage in the following table: +| Stage | Resolution | Num of frames | FPS | Batch size | Train Steps | Discrminator | $\lambda_{lpips}$ | +|:--- |:--- |:--- |:--- |:--- |:--- |:--- |:--- | +| 1 | 256x256 | 25 | Original fps | 8 | 800K | TRUE | - | +| 2 | 256x256 | 49 | Original fps / 2 | 8 | 200K | TRUE | - | +| 3 | 256x256 | 49 | Original fps / 2 | 8 | 200K | TRUE | 0.1 | +See the hyper-parameters in `scripts/causalvae/multi-devices/train_stage_x.sh` + +> Note: +> - We support resume training by setting `--resume_from_checkpoint True`. It is the same for the multi-device training script. +> - We also provide the standalone training script: `scripts/causalvae/single-device/train.sh`. #### Inference After Training -After training, you will find the checkpoint files under the `ckpt/` folder of the output directory. To evaluate the reconstruction of the checkpoint file, you can take `scripts/causalvae/rec_video_folder.sh` and revise it like: +After training, you will find the checkpoint files under the `ckpt/` folder of the output directory. To evaluate the reconstruction of the checkpoint file, you can take `scripts/causalvae/single-device/rec_video_folder.sh` and revise it like: ```bash python examples/rec_video_folder.py \ --batch_size 1 \ - --real_video_dir input_real_video_dir \ - --generated_video_dir output_generated_video_dir \ + --real_video_dir datasets/UCF-101/ \ + --data_file_path datasets/ucf101_test.csv \ + --generated_video_dir recons/ucf101_test/ \ --device Ascend \ - --sample_fps 10 \ + --sample_fps 30 \ --sample_rate 1 \ - --num_frames 65 \ - --height 480 \ - --width 640 \ + --num_frames 25 \ + --height 256 \ + --width 256 \ --num_workers 8 \ - --ae_path LanguageBind/Open-Sora-Plan-v1.2.0/vae \ - --enable_tiling \ - --save_memory \ - --ms_checkpoint /path/to/ms/checkpoint \ + --ae "WFVAEModel_D8_4x8x8" \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ + --ms_checkpoint path/to/ms/ckpt \ ``` Runing this command will generate reconstructed videos under the given `output_generated_video_dir`. You can then evalute some common metrics (e.g., ssim, psnr) using the script under `opensora/eval/script`. - -#### Performance - -Here, we report the training performance and evaluation results on the UCF-101 dataset. Experiments are tested on Ascend 910* with mindspore 2.3.1 graph mode. - -| model name | cards | batch size | resolution | precision | discriminator |sink |recompute| jit level| graph compile | s/step | img/s | psnr | ssim | recipe| -|:-----------|:------ |:-----------:|:----------:|:-------------:|:----------:|:------------:|:---:|:--------:|:--------:|--------:|------:|:----:|-------:|-------:| -| CausalVAE | 8 | 1 | 25x256x256 | BF16 | FALSE | OFF | OFF | O0 | 3 mins | 4.21 | 47.51 | 28.92 | 0.87 | [train](./scripts/causalvae/train_without_gan_loss_multi_device.sh) | -| CausalVAE | 8 | 1 | 25x256x256 | FP32 | TRUE | OFF | ON | O0 | 3 mins | 5.45 | 36.70 | 29.28 | 0.88 | [train](./scripts/causalvae/train_with_gan_loss_multi_device.sh) | - - ### Training Diffusion Model #### Preparation diff --git a/examples/opensora_pku/opensora/models/causalvideovae/sample/rec_video_vae.py b/examples/opensora_pku/opensora/models/causalvideovae/sample/rec_video_vae.py index 7e09aee4fc..0bea0ce9fd 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/sample/rec_video_vae.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/sample/rec_video_vae.py @@ -119,7 +119,7 @@ def main(args: argparse.Namespace): action="store_true", help="Whether to use a random frame as the starting frame for reconstruction. Default is False for the ease of evaluation.", ) - parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") + parser.add_argument("--mode", default=1, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") parser.add_argument( "--precision", default="bf16", diff --git a/examples/opensora_pku/opensora/utils/sample_utils.py b/examples/opensora_pku/opensora/utils/sample_utils.py index 9b5969eea9..53f42d8309 100644 --- a/examples/opensora_pku/opensora/utils/sample_utils.py +++ b/examples/opensora_pku/opensora/utils/sample_utils.py @@ -639,7 +639,7 @@ def get_args(): # MS new args parser.add_argument("--device", type=str, default="Ascend", help="Ascend or GPU") parser.add_argument("--max_device_memory", type=str, default=None, help="e.g. `30GB` for 910a, `59GB` for 910b") - parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") + parser.add_argument("--mode", default=1, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode") parser.add_argument("--use_parallel", default=False, type=str2bool, help="use parallel") parser.add_argument( "--parallel_mode", default="data", type=str, choices=["data", "optim"], help="parallel mode: data, optim" diff --git a/examples/opensora_pku/requirements.txt b/examples/opensora_pku/requirements.txt index f501cacbc2..987beb7b4f 100644 --- a/examples/opensora_pku/requirements.txt +++ b/examples/opensora_pku/requirements.txt @@ -13,10 +13,9 @@ safetensors omegaconf pyyaml sentencepiece -mindnlp==0.4.0 -bs4 -huggingface_hub>=0.22.2 -decord -pillow -tokenizers +beautifulsoup4 +huggingface_hub>=0.22.2,<0.26 transformers +tokenizers +pillow +decord diff --git a/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh b/examples/opensora_pku/scripts/causalvae/multi-devices/rec_video_folder.sh similarity index 92% rename from examples/opensora_pku/scripts/causalvae/rec_video_folder.sh rename to examples/opensora_pku/scripts/causalvae/multi-devices/rec_video_folder.sh index 3ea7b087f5..968affaee2 100644 --- a/examples/opensora_pku/scripts/causalvae/rec_video_folder.sh +++ b/examples/opensora_pku/scripts/causalvae/multi-devices/rec_video_folder.sh @@ -1,3 +1,4 @@ +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="parallel_logs/" examples/rec_video_folder.py \ --batch_size 1 \ --real_video_dir datasets/UCF-101/ \ @@ -12,7 +13,5 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --num_workers 8 \ --ae "WFVAEModel_D8_4x8x8" \ --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ - --mode 1 \ - --jit_syntax_level lax \ --use_parallel True \ # --ms_checkpoint path/to/ms/ckpt diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh b/examples/opensora_pku/scripts/causalvae/multi-devices/train_stage_1.sh similarity index 91% rename from examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh rename to examples/opensora_pku/scripts/causalvae/multi-devices/train_stage_1.sh index 3e349db12f..3ebd63ac50 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss_multi_device.sh +++ b/examples/opensora_pku/scripts/causalvae/multi-devices/train_stage_1.sh @@ -3,7 +3,7 @@ export MS_ENABLE_NUMA=0 export MS_MEMORY_STATISTIC=1 export GLOG_v=2 output_dir="results/causalvae" -exp_name="25x256x256" +exp_name="stage1-25x256x256" msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir=$output_dir/$exp_name/parallel_logs opensora/train/train_causalvae.py \ --exp_name $exp_name \ @@ -25,20 +25,17 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --betas 0.9 0.999 \ --clip_grad True \ --max_grad_norm 1.0 \ - --mode 1 \ --init_loss_scale 65536 \ --jit_level "O0" \ --use_discriminator True \ --use_parallel True \ --use_ema False \ --ema_decay 0.999 \ - --perceptual_weight 1.0 \ + --perceptual_weight 0.0 \ --loss_type l1 \ --sample_rate 1 \ --disc_cls causalvideovae.model.losses.LPIPSWithDiscriminator3D \ --disc_start 0 \ --wavelet_loss \ --wavelet_weight 0.1 \ - --mode 1 \ - --jit_syntax_level lax \ --print_losses diff --git a/examples/opensora_pku/scripts/causalvae/multi-devices/train_stage_2.sh b/examples/opensora_pku/scripts/causalvae/multi-devices/train_stage_2.sh new file mode 100644 index 0000000000..032bd27103 --- /dev/null +++ b/examples/opensora_pku/scripts/causalvae/multi-devices/train_stage_2.sh @@ -0,0 +1,41 @@ +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export MS_ENABLE_NUMA=0 +export MS_MEMORY_STATISTIC=1 +export GLOG_v=2 +output_dir="results/causalvae" +exp_name="stage2-49x256x256" + +msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir=$output_dir/$exp_name/parallel_logs opensora/train/train_causalvae.py \ + --exp_name $exp_name \ + --model_name WFVAE \ + --model_config scripts/causalvae/wfvae_8dim.json \ + --train_batch_size 1 \ + --precision fp32 \ + --max_steps 100000 \ + --save_steps 2000 \ + --output_dir $output_dir \ + --video_path datasets/UCF-101 \ + --data_file_path datasets/ucf101_train.csv \ + --video_num_frames 49 \ + --resolution 256 \ + --dataloader_num_workers 8 \ + --start_learning_rate 1e-5 \ + --lr_scheduler constant \ + --optim adamw \ + --betas 0.9 0.999 \ + --clip_grad True \ + --max_grad_norm 1.0 \ + --init_loss_scale 65536 \ + --jit_level "O0" \ + --use_discriminator True \ + --use_parallel True \ + --use_ema False \ + --ema_decay 0.999 \ + --perceptual_weight 0.0 \ + --loss_type l1 \ + --sample_rate 2 \ + --disc_cls causalvideovae.model.losses.LPIPSWithDiscriminator3D \ + --disc_start 0 \ + --wavelet_loss \ + --wavelet_weight 0.1 \ + --print_losses diff --git a/examples/opensora_pku/scripts/causalvae/multi-devices/train_stage_3.sh b/examples/opensora_pku/scripts/causalvae/multi-devices/train_stage_3.sh new file mode 100644 index 0000000000..e3fe62865d --- /dev/null +++ b/examples/opensora_pku/scripts/causalvae/multi-devices/train_stage_3.sh @@ -0,0 +1,41 @@ +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export MS_ENABLE_NUMA=0 +export MS_MEMORY_STATISTIC=1 +export GLOG_v=2 +output_dir="results/causalvae" +exp_name="stage3-49x256x256" + +msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir=$output_dir/$exp_name/parallel_logs opensora/train/train_causalvae.py \ + --exp_name $exp_name \ + --model_name WFVAE \ + --model_config scripts/causalvae/wfvae_8dim.json \ + --train_batch_size 1 \ + --precision fp32 \ + --max_steps 100000 \ + --save_steps 2000 \ + --output_dir $output_dir \ + --video_path datasets/UCF-101 \ + --data_file_path datasets/ucf101_train.csv \ + --video_num_frames 49 \ + --resolution 256 \ + --dataloader_num_workers 8 \ + --start_learning_rate 1e-5 \ + --lr_scheduler constant \ + --optim adamw \ + --betas 0.9 0.999 \ + --clip_grad True \ + --max_grad_norm 1.0 \ + --init_loss_scale 65536 \ + --jit_level "O0" \ + --use_discriminator True \ + --use_parallel True \ + --use_ema False \ + --ema_decay 0.999 \ + --perceptual_weight 0.1 \ + --loss_type l1 \ + --sample_rate 2 \ + --disc_cls causalvideovae.model.losses.LPIPSWithDiscriminator3D \ + --disc_start 0 \ + --wavelet_loss \ + --wavelet_weight 0.1 \ + --print_losses diff --git a/examples/opensora_pku/scripts/causalvae/release.json b/examples/opensora_pku/scripts/causalvae/release.json deleted file mode 100644 index c24b01fe3c..0000000000 --- a/examples/opensora_pku/scripts/causalvae/release.json +++ /dev/null @@ -1,85 +0,0 @@ -{ - "_class_name": "CausalVAEModel", - "_diffusers_version": "0.27.2", - "attn_resolutions": [], - "decoder_attention": "AttnBlock3DFix", - "decoder_conv_in": "CausalConv3d", - "decoder_conv_out": "CausalConv3d", - "decoder_mid_resnet": "ResnetBlock3D", - "decoder_resnet_blocks": [ - "ResnetBlock3D", - "ResnetBlock3D", - "ResnetBlock3D", - "ResnetBlock3D" - ], - "decoder_spatial_upsample": [ - "", - "SpatialUpsample2x", - "Spatial2xTime2x3DUpsample", - "Spatial2xTime2x3DUpsample" - ], - "decoder_spatial_upsample_unup": [ - "", - "", - "", - "" - ], - "decoder_temporal_upsample": [ - "", - "", - "", - "" - ], - "double_z": true, - "dropout": 0.0, - "embed_dim": 4, - "encoder_attention": "AttnBlock3DFix", - "encoder_conv_in": "Conv2d", - "encoder_conv_out": "CausalConv3d", - "encoder_mid_resnet": "ResnetBlock3D", - "encoder_resnet_blocks": [ - "ResnetBlock2D", - "ResnetBlock2D", - "ResnetBlock3D", - "ResnetBlock3D" - ], - "encoder_spatial_downsample": [ - "Downsample", - "Spatial2xTime2x3DDownsample", - "Spatial2xTime2x3DDownsample", - "" - ], - "encoder_spatial_downsample_undown": [ - "", - "", - "", - "" - ], - "encoder_temporal_downsample": [ - "", - "", - "", - "" - ], - "hidden_size": 128, - "hidden_size_mult": [ - 1, - 2, - 4, - 4 - ], - "in_channels": 3, - "loss_params": { - "disc_start": 1, - "disc_weight": 0.5, - "kl_weight": 1e-06, - "logvar_init": 0.0 - }, - "loss_type": "opensora.models.causalvideovae.model.losses.LPIPSWithDiscriminator", - "lr": 1e-05, - "num_res_blocks": 2, - "out_channels": 3, - "q_conv": "CausalConv3d", - "resolution": 256, - "z_channels": 4 -} diff --git a/examples/opensora_pku/scripts/causalvae/eval.sh b/examples/opensora_pku/scripts/causalvae/single-device/eval.sh similarity index 100% rename from examples/opensora_pku/scripts/causalvae/eval.sh rename to examples/opensora_pku/scripts/causalvae/single-device/eval.sh diff --git a/examples/opensora_pku/scripts/causalvae/rec_image.sh b/examples/opensora_pku/scripts/causalvae/single-device/rec_image.sh similarity index 83% rename from examples/opensora_pku/scripts/causalvae/rec_image.sh rename to examples/opensora_pku/scripts/causalvae/single-device/rec_image.sh index dfb5c6e08c..679642f189 100644 --- a/examples/opensora_pku/scripts/causalvae/rec_image.sh +++ b/examples/opensora_pku/scripts/causalvae/single-device/rec_image.sh @@ -5,5 +5,3 @@ python examples/rec_image.py \ --rec_path rec.jpg \ --device Ascend \ --short_size 512 \ - --mode 1 \ - --jit_syntax_level lax \ diff --git a/examples/opensora_pku/scripts/causalvae/rec_video.sh b/examples/opensora_pku/scripts/causalvae/single-device/rec_video.sh similarity index 75% rename from examples/opensora_pku/scripts/causalvae/rec_video.sh rename to examples/opensora_pku/scripts/causalvae/single-device/rec_video.sh index 244f7acc37..0d6e1f2b38 100644 --- a/examples/opensora_pku/scripts/causalvae/rec_video.sh +++ b/examples/opensora_pku/scripts/causalvae/single-device/rec_video.sh @@ -5,10 +5,8 @@ python examples/rec_video.py \ --rec_path rec.mp4 \ --device Ascend \ --sample_rate 1 \ - --num_frames 61 \ + --num_frames 65 \ --height 512 \ --width 512 \ --fps 30 \ - --enable_tiling \ - --mode 1 \ - --jit_syntax_level lax \ + --enable_tiling diff --git a/examples/opensora_pku/scripts/causalvae/single-device/rec_video_folder.sh b/examples/opensora_pku/scripts/causalvae/single-device/rec_video_folder.sh new file mode 100644 index 0000000000..4bda355614 --- /dev/null +++ b/examples/opensora_pku/scripts/causalvae/single-device/rec_video_folder.sh @@ -0,0 +1,15 @@ +python examples/rec_video_folder.py \ + --batch_size 1 \ + --real_video_dir datasets/UCF-101/ \ + --data_file_path datasets/ucf101_test.csv \ + --generated_video_dir recons/ucf101_test/ \ + --device Ascend \ + --sample_fps 30 \ + --sample_rate 1 \ + --num_frames 25 \ + --height 256 \ + --width 256 \ + --num_workers 8 \ + --ae "WFVAEModel_D8_4x8x8" \ + --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ + # --ms_checkpoint path/to/ms/ckpt diff --git a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh b/examples/opensora_pku/scripts/causalvae/single-device/train.sh similarity index 91% rename from examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh rename to examples/opensora_pku/scripts/causalvae/single-device/train.sh index 37c506f752..659cc0fda5 100644 --- a/examples/opensora_pku/scripts/causalvae/train_with_gan_loss.sh +++ b/examples/opensora_pku/scripts/causalvae/single-device/train.sh @@ -18,19 +18,16 @@ python opensora/train/train_causalvae.py \ --betas 0.9 0.999 \ --clip_grad True \ --max_grad_norm 1.0 \ - --mode 1 \ --init_loss_scale 65536 \ --jit_level "O0" \ --use_discriminator True \ --use_ema False \ --ema_decay 0.999 \ - --perceptual_weight 1.0 \ + --perceptual_weight 0.0 \ --loss_type l1 \ --sample_rate 1 \ --disc_cls causalvideovae.model.losses.LPIPSWithDiscriminator3D \ --disc_start 0 \ --wavelet_loss \ --wavelet_weight 0.1 \ - --mode 1 \ - --jit_syntax_level lax \ --print_losses diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh index 250f509792..5df1976270 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh @@ -21,6 +21,5 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --num_samples_per_prompt 1 \ --rescale_betas_zero_snr \ --prediction_type "v_prediction" \ - --mode 1 \ --precision bf16 \ --use_parallel True \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh index 7ecd0da215..da48064235 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh @@ -1,4 +1,3 @@ - export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="./sample_videos/sora_93x640_mt5_sp/parallel_logs/" \ opensora/sample/sample.py \ @@ -22,7 +21,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --num_samples_per_prompt 1 \ --rescale_betas_zero_snr \ --prediction_type "v_prediction" \ - --mode 1 \ --precision bf16 \ --use_parallel True \ --sp_size 8 diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh deleted file mode 100644 index d79fb40d14..0000000000 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_debug.sh +++ /dev/null @@ -1,61 +0,0 @@ -# Stage 2: 93x320x320 -NUM_FRAME=29 -WIDTH=320 -HEIGHT=320 -ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 -msrun --bind_core=True --worker_num=4 --local_worker_num=4 --master_port=6000 --log_dir="./checkpoints/t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}_zero2_mode1_npu4/parallel_logs" \ - opensora/train/train_t2v_diffusers.py \ - --model OpenSoraT2V_v1_3-2B/122 \ - --text_encoder_name_1 /home_host/susan/workspace/checkpoints/google/mt5-xxl \ - --cache_dir "./" \ - --dataset t2v \ - --data "scripts/train_data/video_data_v1_2.txt" \ - --ae WFVAEModel_D8_4x8x8 \ - --ae_path /home_host/susan/workspace/checkpoints/LanguageBind/Open-Sora-Plan-v1.3.0/vae \ - --sample_rate 1 \ - --num_frames ${NUM_FRAME} \ - --max_height ${HEIGHT} \ - --max_width ${WIDTH} \ - --interpolation_scale_t 1.0 \ - --interpolation_scale_h 1.0 \ - --interpolation_scale_w 1.0 \ - --gradient_checkpointing \ - --train_batch_size=1 \ - --dataloader_num_workers 4 \ - --gradient_accumulation_steps=1 \ - --max_train_steps=1000000 \ - --start_learning_rate=2e-5 \ - --lr_scheduler="constant" \ - --seed=10 \ - --lr_warmup_steps=500 \ - --precision="bf16" \ - --checkpointing_steps=1000 \ - --output_dir="./checkpoints/t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}_zero2_mode1_npu4/" \ - --model_max_length 512 \ - --use_image_num 0 \ - --cfg 0.1 \ - --snr_gamma 5.0 \ - --use_ema False \ - --ema_start_step 0 \ - --enable_tiling \ - --clip_grad True \ - --max_grad_norm 1.0 \ - --noise_offset 0.02 \ - --ema_decay 0.999 \ - --speed_factor 1.0 \ - --drop_short_ratio 1.0 \ - --use_parallel True \ - --parallel_mode "zero" \ - --zero_stage 2 \ - --max_device_memory "58GB" \ - --jit_syntax_level "lax" \ - --dataset_sink_mode True \ - --num_no_recompute 18 \ - --prediction_type "v_prediction" \ - --hw_stride 32 \ - --sparse1d \ - --sparse_n 4 \ - --train_fps 16 \ - --trained_data_global_step 0 \ - --group_data \ - --mode 1 diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh index 71712662c1..73dac17ebf 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh @@ -50,7 +50,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --parallel_mode "zero" \ --zero_stage 2 \ --max_device_memory "59GB" \ - --jit_syntax_level "lax" \ --dataset_sink_mode False \ --prediction_type "v_prediction" \ --hw_stride 32 \ @@ -59,4 +58,3 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --train_fps 16 \ --trained_data_global_step 0 \ --group_data \ - --mode 1 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh index 7ac9fc44b5..4e6eb05c7c 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh @@ -51,7 +51,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --parallel_mode "zero" \ --zero_stage 2 \ --max_device_memory "59GB" \ - --jit_syntax_level "lax" \ --dataset_sink_mode False \ --prediction_type "v_prediction" \ --hw_stride 32 \ @@ -60,6 +59,5 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --train_fps 16 \ --trained_data_global_step 0 \ --group_data \ - --mode 1 \ --sp_size 8 \ --train_sp_batch_size 1 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh index 05d544f23c..6029a0a8ca 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh @@ -50,7 +50,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --parallel_mode "zero" \ --zero_stage 2 \ --max_device_memory "59GB" \ - --jit_syntax_level "lax" \ --dataset_sink_mode False \ --prediction_type "v_prediction" \ --hw_stride 32 \ @@ -59,6 +58,5 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --train_fps 16 \ --trained_data_global_step 0 \ --group_data \ - --mode 1 \ # --sp_size 8 \ # --train_sp_batch_size 1 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_debug.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_debug.sh deleted file mode 100644 index b4cc745ef0..0000000000 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_debug.sh +++ /dev/null @@ -1,36 +0,0 @@ -# Quick debug for DiT config: -# - 1 NPU/GPU -# - fewer frames: 29 -# - small and uncommon resolution: 352x640 -# - fps: 24 -# - precision: bf16 (Some exceptions ref to sample_utils.py. Torch ver doesn't share, always uses fp16. ) - -# Debug first prompt only: -# "A young man at his 20s is sitting on a piece of cloud in the sky, reading a book." - -# To use: -# change model_path, text_encoder_name_1, ae_path, save_img_path before running the script. - -export DEVICE_ID=0 -python opensora/sample/sample.py \ - --model_path /home_host/susan/workspace/checkpoints/LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ - --version v1_3 \ - --num_frames 29 \ - --height 352 \ - --width 640 \ - --text_encoder_name_1 /home_host/susan/workspace/checkpoints/google/mt5-xxl \ - --text_prompt examples/prompt_list_0.txt \ - --ae WFVAEModel_D8_4x8x8 \ - --ae_path /home_host/susan/workspace/checkpoints/LanguageBind/Open-Sora-Plan-v1.3.0/vae \ - --save_img_path "./sample_videos/prompt_list_0_29x640_mt5_bf16_debug" \ - --fps 24 \ - --guidance_scale 7.5 \ - --num_sampling_steps 100 \ - --enable_tiling \ - --max_sequence_length 512 \ - --sample_method EulerAncestralDiscrete \ - --seed 1234 \ - --num_samples_per_prompt 1 \ - --rescale_betas_zero_snr \ - --prediction_type "v_prediction" \ - --mode 1 --precision bf16 \ No newline at end of file diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x320x320.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x320x320.sh index 29b6489612..690caf8c2a 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x320x320.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x320x320.sh @@ -23,6 +23,5 @@ python opensora/sample/sample.py \ --num_samples_per_prompt 1 \ --rescale_betas_zero_snr \ --prediction_type "v_prediction" \ - --mode 1 \ --precision bf16 \ --ms_checkpoint ckpt/path \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh index 3ebdb7be9c..d8b54a5727 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh @@ -23,5 +23,4 @@ python opensora/sample/sample.py \ --num_samples_per_prompt 1 \ --rescale_betas_zero_snr \ --prediction_type "v_prediction" \ - --mode 1 \ --precision bf16 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_2texenc.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_2texenc.sh index 5794be27e1..54b3db8e30 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_2texenc.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_2texenc.sh @@ -23,4 +23,3 @@ python opensora/sample/sample.py \ --num_samples_per_prompt 1 \ --rescale_betas_zero_snr \ --prediction_type "v_prediction" \ - --mode 1 diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh index 9e6bb6b1fb..99dccfa1c3 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh @@ -44,7 +44,6 @@ python opensora/train/train_t2v_diffusers.py \ --speed_factor 1.0 \ --drop_short_ratio 0.0 \ --max_device_memory "59GB" \ - --jit_syntax_level "lax" \ --dataset_sink_mode False \ --prediction_type "v_prediction" \ --hw_stride 32 \ @@ -53,4 +52,3 @@ python opensora/train/train_t2v_diffusers.py \ --train_fps 16 \ --trained_data_global_step 0 \ --group_data \ - --mode 1 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh index e380c1704c..00f3c70ce7 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh @@ -44,7 +44,6 @@ python opensora/train/train_t2v_diffusers.py \ --speed_factor 1.0 \ --drop_short_ratio 0.0 \ --max_device_memory "59GB" \ - --jit_syntax_level "lax" \ --dataset_sink_mode False \ --prediction_type "v_prediction" \ --hw_stride 32 \ @@ -53,4 +52,3 @@ python opensora/train/train_t2v_diffusers.py \ --train_fps 16 \ --trained_data_global_step 0 \ --group_data \ - --mode 1 diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh index 1b7105f3fe..3fd7b55de3 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh @@ -44,7 +44,6 @@ python opensora/train/train_t2v_diffusers.py \ --speed_factor 1.0 \ --drop_short_ratio 0.0 \ --max_device_memory "59GB" \ - --jit_syntax_level "lax" \ --dataset_sink_mode False \ --prediction_type "v_prediction" \ --hw_stride 32 \ @@ -53,4 +52,3 @@ python opensora/train/train_t2v_diffusers.py \ --train_fps 16 \ --trained_data_global_step 0 \ --group_data \ - --mode 1 From fa850d0b98467aeba0e52613acf070de3dcfdcab Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 16 Dec 2024 16:51:26 +0800 Subject: [PATCH 105/133] update t2v training --- examples/opensora_pku/README.md | 97 ++++++++----------- .../multi-devices/train_t2i_stage1.sh | 1 + .../multi-devices/train_t2v_stage2.sh | 1 + .../multi-devices/train_t2v_stage3.sh | 3 +- .../single-device/sample_t2i_1x320x320.sh | 2 +- .../single-device/train_t2i_stage1.sh | 4 +- .../single-device/train_t2v_stage2.sh | 14 ++- .../single-device/train_t2v_stage3.sh | 8 +- 8 files changed, 61 insertions(+), 69 deletions(-) diff --git a/examples/opensora_pku/README.md b/examples/opensora_pku/README.md index 1298598e47..e722a5fb5a 100644 --- a/examples/opensora_pku/README.md +++ b/examples/opensora_pku/README.md @@ -276,7 +276,7 @@ In order to train vae with LPIPS loss, please also download [lpips_vgg-426bf45c. **Steps 3: Hyper-parameters Setting** -Please find the hyper-parameters in each stage in the following table: +As introduced in the [Open-Sora Plan Arxiv paper](https://arxiv.org/abs/2412.00131), the hyper-parameters of each stage is summerized in the following table: | Stage | Resolution | Num of frames | FPS | Batch size | Train Steps | Discrminator | $\lambda_{lpips}$ | |:--- |:--- |:--- |:--- |:--- |:--- |:--- |:--- | | 1 | 256x256 | 25 | Original fps | 8 | 800K | TRUE | - | @@ -319,8 +319,7 @@ Runing this command will generate reconstructed videos under the given `output_g **Step 1: Downloading Datasets**: - -The [Open-Sora-Dataset-v1.2.0](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0) contains annotation json files, which are listed below: +Open-Sora-Dataset-v1.3.0 dataset is the same as the dataset used in [Open-Sora-Dataset-v1.2.0](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0) which contains annotation json files listed below: ```text Panda70M_HQ1M.json @@ -331,7 +330,7 @@ v1.1.0_HQ_part2.json v1.1.0_HQ_part3.json ``` -Please check the [readme doc](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0) for details of these annotation files. [Open-Sora-Dataset-v1.2.0](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0) contains the [Panda70M (training full)](https://drive.google.com/file/d/1DeODUcdJCEfnTjJywM-ObmrlVg-wsvwz/view?usp=sharing), [SAM](https://ai.meta.com/datasets/segment-anything/), and the data from [Open-Sora-Dataset-v1.1.0](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main). You can take the following instructions only how to download [Open-Sora-Dataset-v1.1.0](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main). +Please check the [readme doc](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0) for details of these annotation files. [Open-Sora-Dataset-v1.2.0](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.2.0) contains the [Panda70M (training full)](https://drive.google.com/file/d/1DeODUcdJCEfnTjJywM-ObmrlVg-wsvwz/view?usp=sharing), [SAM](https://ai.meta.com/datasets/segment-anything/) and the data from [Open-Sora-Dataset-v1.1.0](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main). You can take the following instructions on how to download [Open-Sora-Dataset-v1.1.0](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.1.0/tree/main).
@@ -399,7 +398,7 @@ anno_jsons/ **Step 2: Extracting Embedding Cache**: -Next, please extract the text embeddings and save them in the disk for training acceleration. For each json file, you need to run the following command accordingly and save the t5 embeddings cache in the `output_path`. +Next, please extract the text embeddings and save them in the disk for training acceleration. For each json file, you need to run the following command accordingly and save the mt5-xxl embeddings cache in the `output_path`. ```bash python opensora/sample/sample_text_embed.py \ @@ -419,24 +418,37 @@ annotation json path: e.g., datasets/anno_jsons/Panda70M_HQ1M.json ``` In the dataset file, for example, `scripts/train_data/merge_data.txt`, each line represents one dataset. Each line includes three paths: the images/videos folder, the text embedding cache folder, and the path to the annotation json file. Please revise them accordingly to the paths on your disk. +**Step 4: Hyper-Parameters Setting** + + +As introduced in the [Open-Sora Plan Arxiv paper](https://arxiv.org/abs/2412.00131), the hyper-parameters of each stage is summerized in the following table: + +| Stage | Resolution | Num of frames | Datasets | Batch size | Train Steps | LR | Attention | +|:--- |:--- |:--- |:--- |:--- |:--- |:--- |:--- | +| 1 (T2I) | 320x320 | 1 | SAM, AnyText, Human Images | 1024 | 150K (full-attention) + 100K (skiparse attention) | 2e-5 | Full 3D -> Skiparse | +| 2 (T2I&T2V) | maximumly 93×640×640 | 49 | SAM, Panda70M | 1024 | 200K | 2e-5 | Skiparse | +| 3 (T2V) | 93x352x640 | 49 | filtered Panda70M, high-quality data | 1024 | 100K~200K | 1e-5 | Skiparse | + #### Example of Training Scripts The training scripts are stored under `scripts/text_condition`. The single-device training scripts are under the `single-device` folder for demonstration. We recommend to use the parallel-training scripts under the `multi-devices` folder. -Here we choose an example of training scripts (`train_video3d_nx480p_zero2.sh`) and explain the meanings of some experimental arguments. +Here we choose an example of training scripts (`train_t2i_stage1.sh`) and explain the meanings of some experimental arguments. Here is the major command of the training script: ```shell -NUM_FRAME=29 -python opensora/train/train_t2v_diffusers.py \ - --data "scripts/train_data/merge_data.txt" \ +NUM_FRAME=1 +WIDTH=320 +HEIGHT=320 +python opensora/train/train_t2v_diffusers.py \ + --data "scripts/train_data/image_data_v1_2.txt" \ --num_frames ${NUM_FRAME} \ - --max_height 480 \ - --max_width 640 \ - --attention_mode xformers \ + --force_resolution \ + --max_height ${HEIGHT} \ + --max_width ${WIDTH} \ --gradient_checkpointing \ - --pretrained "path/to/ms-or-safetensors-ckpt/from/last/stage" \ + --pretrained path/to/last/stage/ckpt \ --parallel_mode "zero" \ --zero_stage 2 \ # pass other arguments @@ -445,13 +457,13 @@ There are some arguments related to the training dataset path: - `data`: the text file to the video/image dataset. The text file should contain N lines corresponding to N datasets. Each line should have two or three items. If two items are available, they correspond to the video folder and the annotation json file. If three items are available, they correspond to the video folder, the text embedding cache folder, and the annotation json file. - `num_frames`: the number of frames of each video sample. - `max_height` and `max_width`: the frame maximum height and width. -- `attention_mode`: the attention mode, choosing from `math` or `xformers`. Note that we are not using the actual [xformers](https://link.zhihu.com/?target=https%3A//github.com/facebookresearch/xformers) library to accelerate training, but using MindSpore-native `FlashAttentionScore`. The `xformers` is kept for compatibility and maybe discarded in the future. +- `force_resolution`: whether to train with fixed resolution or dynamic resolution. If `force_resolution` is True, all videos will be cropped and resized to the resolution of `args.max_height x args.max_width`. If `force_resolution` is False, `args.max_hxw` must be provided which determines the maximum token length of each video tensor. - `gradient_checkpointing`: it is referred to MindSpore [recomputation](https://www.mindspore.cn/docs/en/r2.3.1/api_python/mindspore/mindspore.recompute.html) feature, which can save memory by recomputing the intermediate activations in the backward pass. -- `pretrained`: the pretrained checkpoint to be loaded as initial weights before training. If not provided, the OpenSoraT2V will use random initialization. If provided, the path should be either the safetensors checkpoint directiory or path, e.g., "LanguageBind/Open-Sora-Plan-v1.2.0/1x480p" or "LanguageBind/Open-Sora-Plan-v1.2.0/1x480p/diffusion_pytorch_model.safetensors", or MindSpore checkpoint path, e.g., "t2i-image3d-1x480p/ckpt/OpenSoraT2V-ROPE-L-122.ckpt". +- `pretrained`: the pretrained checkpoint to be loaded as initial weights before training. If not provided, the OpenSoraT2V will use random initialization. - `parallel_mode`: the parallelism mode chosen from ["data", "optim", "zero"], which denotes the data parallelism, the optimizer parallelism and the deepspeed zero_x parallelism. -- `zero_stage`: runs parallelism like deepspeed, supporting zero0, zero1, zero2, and zero3, if parallel_mode is "zero". +- `zero_stage`: runs parallelism like deepspeed, supporting zero0, zero1, zero2, and zero3, if parallel_mode is "zero". By default, we use `--zero_stage 2` for all training stages. -For the stage 4 (`29x720p`) and stage 5 (`93x720p`) training script, please refer to `train_video3d_29x720p_zero2_sp.sh` and `train_video3d_93x720p_zero2_sp.sh`. +For the stage 2 and stage 3 training scripts, please refer to `train_t2v_stage2.sh` and `train_t2v_stage3.sh`. #### Validation During Training @@ -464,56 +476,23 @@ We also support to run validation during training. This is supported by editing + --val_batch_size 1 \ + --val_interval 1 \ ``` -The edits allow to compute the loss on the validation set specified by `merge_data_val.txt` for every 1 epoch (defined by `val_interval`). `merge_data_val.txt` has the same format as `merge_data_train.txt`, but specifies a different subset from the train set. The validation loss will be recorded in the `result_val.log` under the output directory. For example training script, please refer to `train_video3d_29x720p_zero2_sp_val.sh` under `scripts/text_conditions/multi-devices/`. - +The edits allow to compute the loss on the validation set specified by `merge_data_val.txt` for every 1 epoch (defined by `val_interval`). `merge_data_val.txt` has the same format as `merge_data_train.txt`, but specifies a different subset from the train set. The validation loss will be recorded in the `result_val.log` under the output directory. #### Sequence Parallelism -We also support training with sequence parallelism and zero2 parallelism together. This is enabled by setting `--sp_size`. For example, with `sp_size=8`, 8 NPUs are used for a single video sample. - -See `train_video3d_29x720p_zero2_sp.sh` under `scripts/text_condition/mult-devices/` for detailed usage. - -#### Multi-node Training +We also support training with sequence parallelism and zero2 parallelism together. This is enabled by setting `--sp_size`. -When training on NPU clusters, you may need to train with multiple nodes and multiple devices. Here we provide some example scripts for training on 2 nodes, 16 NPUs. See `train_video3d_nx480p_zero2_multi_node.sh` and `train_video3d_29x720p_zero2_sp_multi_node.sh` under `scripts/text_condition/mult-devices/` for detailed usage. - -The major differences between the single-node training script and the multi-node training sccript are as follows: -```bash -MS_WORKER_NUM=16 # the total number of workers in all nodes -LOCAL_WORKER_NUM=8 # the number of workers in the current node -NODE_RANK=$1 # the ID of the current node, pass it via `bash xxx.sh 0` or `bash xxx.sh 1` -MASTER_NODE_ADDRESS="x.xxx.xxx.xxx" # the address of the master node. Use the same master address in two nodes -``` -`MS_WORKER_NUM` means the total number of workers in the two nodes, which is 16. `LOCAL_WORKER_NUM` is the number of workers in the current node, which is 8, since we have 8 NPUs in each node. `NODE_RANK` is the rank ID of each node. By default, the master node's rank id is 0, and the other node's rank id is 1. You can set the node rank id by using `bash xxx.sh 0` or `bash xxx.sh 1`. Finally, `MASTER_NODE_ADDRESS` is the address of the master node and please edit it to your master **server address**. - -Suppose we have two nodes: node_0 and node_1. Each node has 8 NPUs. Please follow the steps below to launch a two-node training of stage 3 using `train_video3d_nx480p_zero2_multi_node.sh`: -> 1. Prepare the datasets and edit the `merge_data.txt` on the two nodes following the instructions of [Sec. Preparation](./README.md#preparation-1). -> 2. Edit the `MASTER_NODE_ADDRESS` in `train_video3d_nx480p_zero2_multi_node.sh` on both node_0 and node_1. `MASTER_NODE_ADDRESS` should be the server address of node_0. You should use the same master address in two nodes. -> 3. In the master node, run `bash train_video3d_nx480p_zero2_multi_node.sh 0`, and in the other node, run `bash train_video3d_nx480p_zero2_multi_node.sh 1`. This the **only difference** between the training scripts on the two nodes. - -#### Tips on Finetuning - -To align with the hyper-parameters, we use the same learning rate (LR) $1e^{-4}$ as [Open-Sora-Plan v1.2.0](https://github.com/PKU-YuanGroup/Open-Sora-Plan/tree/v1.2.0). However, our experience indicates that $1e^{-4}$ might be too large for finetuning the model on a small training set. If you want to finetune Open-Sora-Plan on your custom data with a small size, and notice that the large LR leads to unstable training, we have a few tips for you: - -1. You can lower your LR or increase the effective batch size, for example, by increasing `gradient_accumulation_steps` or running multi-machine training. -2. You can try a different LR scheduler, for example, you can change the current constant LR scheduler to `polynomial decay` by: -```diff -- --lr_scheduler="constant" \ -+ --lr_scheduler="polynomial_decay" \ -+ --lr_decay_steps=1000000 \ -``` -The edits will set the polynomial_decay LR scheduler, and decay the start LR to the end LR in 1000000 steps. You can adjust `lr_decay_steps` based on your `max_train_steps`. See other options of LR scheduler in `mindone/trainers/lr_schedule.py`. +See `train_t2v_stage2.sh` under `scripts/text_condition/mult-devices/` for detailed usage. #### Performance -The training performance are tested on ascend 910* with mindspore 2.3.1 graph mode. The results are as follows. +We evaluated the training performance on Ascend NPUs. The results are as follows. -| model name | cards | stage |batch size | num frames| resolution | graph compile | parallelism | recompute | sink | jit level| s/step | img/s | -|:----------------|:----------- |:----------|:---------:|:-----:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|-------------------:|:----------:| -| OpenSoraT2V-ROPE-L-122 | 8 | 2 | 8 | 1 | 640x480 | 3mins | zero2 | ON | ON | O0 | 2.35 | 27.23 | -| OpenSoraT2V-ROPE-L-122 | 8 | 3 | 1 | 29 | 640x480 | 6mins | zero2 | ON | ON | O0 | 3.68 | 63.04 | -| OpenSoraT2V-ROPE-L-122 | 8 | 4 | 1 | 29 |1280x720 | 10mins | zero2 + SP(sp_size=8) | OFF | ON | O0 | 4.32 | 6.71 | -| OpenSoraT2V-ROPE-L-122 | 8 | 5 | 1 | 93 | 1280x720 | 15mins | zero2 + SP(sp_size=8) | ON | ON | O0 | 24.40 | 3.81 | +| model name | cards | stage |graph compile | batch size (local) | video size | Paramllelism |recompute |data sink | jit level| step time | train imgs/s | +|:----------------|:----------- |:----------|:---------:|:-----:|:----------:|:----------:|:----------:|:----------:|:----------:|-------------------:|:----------:| +| OpenSoraT2V_v1_3-2B/122 | 8 | 1 | mins | 8 | 1x320x320 | zero2 | TRUE | FALSE | O0 | 5.1s | | +| OpenSoraT2V_v1_3-2B/122 | 8 | 2 | mins | 1 | up to 93x640x640 | zero2 + SP(sp_size=8) | TRUE | FALSE | O0 | | | +| OpenSoraT2V_v1_3-2B/122 | 8 | 3 | mins | 8 | 93x320x320 | zero2 | TRUE | FALSE | O0 | | | > SP: sequence parallelism. > diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh index 73dac17ebf..f346460583 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh @@ -58,3 +58,4 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --train_fps 16 \ --trained_data_global_step 0 \ --group_data \ + --pretrained path/to/last/stage/ckpt \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh index 4e6eb05c7c..463407cdee 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh @@ -61,3 +61,4 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --group_data \ --sp_size 8 \ --train_sp_batch_size 1 \ + --pretrained path/to/last/stage/ckpt \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh index 6029a0a8ca..a857e3be8a 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh @@ -1,4 +1,4 @@ -# Stage 3: 93x480x480 (480x480, 640x352, 352x640) +# Stage 3: 93x640x352 (480x480, 640x352, 352x640) NUM_FRAME=93 WIDTH=640 HEIGHT=352 @@ -58,5 +58,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --train_fps 16 \ --trained_data_global_step 0 \ --group_data \ + --pretrained path/to/last/stage/ckpt \ # --sp_size 8 \ # --train_sp_batch_size 1 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x320x320.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x320x320.sh index 690caf8c2a..f21ac031ee 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x320x320.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x320x320.sh @@ -12,7 +12,7 @@ python opensora/sample/sample.py \ --text_prompt examples/prompt_list_human_images.txt \ --ae WFVAEModel_D8_4x8x8 \ --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ - --save_img_path "./sample_videos/human_images" \ + --save_img_path "./sample_images/human_images" \ --fps 18 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh index 99dccfa1c3..9d0d38819b 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh @@ -7,11 +7,12 @@ python opensora/train/train_t2v_diffusers.py \ --text_encoder_name_1 google/mt5-xxl \ --cache_dir "./" \ --dataset t2v \ - --data "scripts/train_data/merge_data_human_image.txt" \ + --data "scripts/train_data/image_data_v1_2.txt" \ --ae WFVAEModel_D8_4x8x8 \ --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ --sample_rate 1 \ --num_frames ${NUM_FRAME} \ + --force_resolution \ --max_height ${HEIGHT} \ --max_width ${WIDTH} \ --interpolation_scale_t 1.0 \ @@ -52,3 +53,4 @@ python opensora/train/train_t2v_diffusers.py \ --train_fps 16 \ --trained_data_global_step 0 \ --group_data \ + --pretrained path/to/last/stage/ckpt \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh index 00f3c70ce7..d22465b2aa 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh @@ -1,8 +1,9 @@ -# Stage 2: 93x320x320 +# Stage 2: 93x640x640 NUM_FRAME=93 -WIDTH=320 -HEIGHT=320 -python opensora/train/train_t2v_diffusers.py \ +WIDTH=640 +HEIGHT=640 +MAX_HxW=409600 +python opensora/train/train_t2v_diffusers.py \ --model OpenSoraT2V_v1_3-2B/122 \ --text_encoder_name_1 google/mt5-xxl \ --cache_dir "./" \ @@ -14,6 +15,7 @@ python opensora/train/train_t2v_diffusers.py \ --num_frames ${NUM_FRAME} \ --max_height ${HEIGHT} \ --max_width ${WIDTH} \ + --max_hxw ${MAX_HxW} \ --interpolation_scale_t 1.0 \ --interpolation_scale_h 1.0 \ --interpolation_scale_w 1.0 \ @@ -43,6 +45,9 @@ python opensora/train/train_t2v_diffusers.py \ --ema_decay 0.999 \ --speed_factor 1.0 \ --drop_short_ratio 0.0 \ + --use_parallel False \ + --parallel_mode "zero" \ + --zero_stage 2 \ --max_device_memory "59GB" \ --dataset_sink_mode False \ --prediction_type "v_prediction" \ @@ -52,3 +57,4 @@ python opensora/train/train_t2v_diffusers.py \ --train_fps 16 \ --trained_data_global_step 0 \ --group_data \ + --pretrained path/to/last/stage/ckpt \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh index 3fd7b55de3..7107bd6a8e 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh @@ -1,7 +1,7 @@ -# Stage 3: 93x480x480 (480x480, 640x352, 352x640) +# Stage 3: 93x640x352 (480x480, 640x352, 352x640) NUM_FRAME=93 -WIDTH=480 -HEIGHT=480 +WIDTH=640 +HEIGHT=352 python opensora/train/train_t2v_diffusers.py \ --model OpenSoraT2V_v1_3-2B/122 \ --text_encoder_name_1 google/mt5-xxl \ @@ -12,6 +12,7 @@ python opensora/train/train_t2v_diffusers.py \ --ae_path LanguageBind/Open-Sora-Plan-v1.3.0/vae \ --sample_rate 1 \ --num_frames ${NUM_FRAME} \ + --force_resolution \ --max_height ${HEIGHT} \ --max_width ${WIDTH} \ --interpolation_scale_t 1.0 \ @@ -52,3 +53,4 @@ python opensora/train/train_t2v_diffusers.py \ --train_fps 16 \ --trained_data_global_step 0 \ --group_data \ + --pretrained path/to/last/stage/ckpt \ From 5573d60a1ffe261682605981e793cf490336a340 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 16 Dec 2024 17:34:37 +0800 Subject: [PATCH 106/133] update sample shape to 352x640 --- examples/opensora_pku/README.md | 27 +------------------ .../multi-devices/sample_t2v_93x640_ddp.sh | 2 +- .../multi-devices/sample_t2v_93x640_sp.sh | 2 +- .../single-device/sample_t2v_93x640.sh | 2 +- 4 files changed, 4 insertions(+), 29 deletions(-) diff --git a/examples/opensora_pku/README.md b/examples/opensora_pku/README.md index e722a5fb5a..13586c3b07 100644 --- a/examples/opensora_pku/README.md +++ b/examples/opensora_pku/README.md @@ -2,7 +2,7 @@ Here we provide an efficient MindSpore version of [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan/tree/main) from Peking University. We would like to express our gratitude to their contributions! :+1: -**OpenSora-PKU is still under active development.** Currently, we are in line with **Open-Sora-Plan v1.2.0** ([commit id](https://github.com/PKU-YuanGroup/Open-Sora-Plan/commit/294993ca78bf65dec1c3b6fb25541432c545eda9)). +**OpenSora-PKU is still under active development.** Currently, we are in line with **Open-Sora-Plan v1.3.0** ([commit id](https://github.com/PKU-YuanGroup/Open-Sora-Plan/commit/9fa322fbbb276e2bbe40b2f439e3d610af3d7690)). ## 📰 News & States @@ -25,31 +25,6 @@ Here we provide an efficient MindSpore version of [Open-Sora-Plan](https://githu | :---: | :---: | :---: | :---: | | 2.3.1 | 24.1RC2 |7.3.0.1.231| 8.0.RC2.beta1 | -## 🎥 Demo - -The following videos are generated based on MindSpore and Ascend 910*. - -Open-Sora-Plan v1.3.0 Demo - -93×640×640 Text-to-Video Generation. - -| 29x720x1280 (1.2s) | -| --- | -| | -| A close-up of a woman’s face, illuminated by the soft light of dawn... | - -| 29x720x1280 (1.2s) | -| --- | -| | -| 0-A young man at his 20s is sitting on a piece of cloud in the sky, reading a book... | - -| 29x720x1280 (1.2s) | -| --- | -| | -| 0-A close-up of a woman with a vintage hairstyle and bright red lipstick... | - -Videos are saved to `.gif` for display. - ## 🔆 Features - 📍 **Open-Sora-Plan v1.3.0** with the following features diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh index 5df1976270..c7ecbb55c7 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh @@ -4,7 +4,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ --version v1_3 \ --num_frames 93 \ - --height 640 \ + --height 352 \ --width 640 \ --text_encoder_name_1 google/mt5-xxl \ --text_prompt examples/sora.txt \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh index da48064235..d5e0e1a5c3 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh @@ -4,7 +4,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ --version v1_3 \ --num_frames 93 \ - --height 640 \ + --height 352 \ --width 640 \ --text_encoder_name_1 google/mt5-xxl \ --text_prompt examples/sora.txt \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh index d8b54a5727..e55d833bb9 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh @@ -6,7 +6,7 @@ python opensora/sample/sample.py \ --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ --version v1_3 \ --num_frames 93 \ - --height 640 \ + --height 352 \ --width 640 \ --text_encoder_name_1 google/mt5-xxl \ --text_prompt examples/sora.txt \ From b702e54ab0fcade7c3a1246bb693180cf0bb49f4 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 16 Dec 2024 17:40:19 +0800 Subject: [PATCH 107/133] update readmd --- examples/opensora_pku/README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/opensora_pku/README.md b/examples/opensora_pku/README.md index 13586c3b07..fc5e937e52 100644 --- a/examples/opensora_pku/README.md +++ b/examples/opensora_pku/README.md @@ -461,13 +461,13 @@ See `train_t2v_stage2.sh` under `scripts/text_condition/mult-devices/` for detai #### Performance -We evaluated the training performance on Ascend NPUs. The results are as follows. +We evaluated the training performance on Ascend NPUs. All experiments are running in PYNATIVE mode. The results are as follows. -| model name | cards | stage |graph compile | batch size (local) | video size | Paramllelism |recompute |data sink | jit level| step time | train imgs/s | -|:----------------|:----------- |:----------|:---------:|:-----:|:----------:|:----------:|:----------:|:----------:|:----------:|-------------------:|:----------:| -| OpenSoraT2V_v1_3-2B/122 | 8 | 1 | mins | 8 | 1x320x320 | zero2 | TRUE | FALSE | O0 | 5.1s | | -| OpenSoraT2V_v1_3-2B/122 | 8 | 2 | mins | 1 | up to 93x640x640 | zero2 + SP(sp_size=8) | TRUE | FALSE | O0 | | | -| OpenSoraT2V_v1_3-2B/122 | 8 | 3 | mins | 8 | 93x320x320 | zero2 | TRUE | FALSE | O0 | | | +| model name | cards | stage | batch size (global) | video size | Paramllelism |recompute |data sink | jit level| step time | train imgs/s | +|:----------------|:----------- |:---------:|:-----:|:----------:|:----------:|:----------:|:----------:|:----------:|-------------------:|:----------:| +| OpenSoraT2V_v1_3-2B/122 | 8 | 1 | 8 | 1x320x320 | zero2 | TRUE | FALSE | O0 | 5.1s | | +| OpenSoraT2V_v1_3-2B/122 | 8 | 2 | 1 | up to 93x640x640 | zero2 + SP(sp_size=8) | TRUE | FALSE | O0 | | | +| OpenSoraT2V_v1_3-2B/122 | 8 | 3 | 8 | 93x352x640 | zero2 | TRUE | FALSE | O0 | | | > SP: sequence parallelism. > From 33fbd4d97df8a5b5e4105fcf18ce023756d36eb2 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 16 Dec 2024 18:33:20 +0800 Subject: [PATCH 108/133] disable enable_tiling by default --- examples/opensora_pku/README.md | 9 ++++----- examples/opensora_pku/opensora/utils/sample_utils.py | 2 +- .../multi-devices/sample_t2v_93x640_ddp.sh | 1 - .../text_condition/multi-devices/sample_t2v_93x640_sp.sh | 1 - .../text_condition/multi-devices/train_t2i_stage1.sh | 7 +++---- .../text_condition/multi-devices/train_t2v_stage2.sh | 1 - .../text_condition/multi-devices/train_t2v_stage3.sh | 1 - .../{sample_t2i_1x320x320.sh => sample_t2i.sh} | 7 +++---- .../text_condition/single-device/sample_t2v_93x640.sh | 1 - .../single-device/sample_t2v_93x640_2texenc.sh | 1 - .../text_condition/single-device/train_t2i_stage1.sh | 7 +++---- .../text_condition/single-device/train_t2v_stage2.sh | 1 - .../text_condition/single-device/train_t2v_stage3.sh | 1 - 13 files changed, 14 insertions(+), 26 deletions(-) rename examples/opensora_pku/scripts/text_condition/single-device/{sample_t2i_1x320x320.sh => sample_t2i.sh} (87%) diff --git a/examples/opensora_pku/README.md b/examples/opensora_pku/README.md index fc5e937e52..48086674bb 100644 --- a/examples/opensora_pku/README.md +++ b/examples/opensora_pku/README.md @@ -169,7 +169,6 @@ python opensora/sample/sample.py \ --fps 18 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ - --enable_tiling \ --max_sequence_length 512 \ --sample_method EulerAncestralDiscrete \ --seed 1234 \ @@ -400,7 +399,7 @@ As introduced in the [Open-Sora Plan Arxiv paper](https://arxiv.org/abs/2412.001 | Stage | Resolution | Num of frames | Datasets | Batch size | Train Steps | LR | Attention | |:--- |:--- |:--- |:--- |:--- |:--- |:--- |:--- | -| 1 (T2I) | 320x320 | 1 | SAM, AnyText, Human Images | 1024 | 150K (full-attention) + 100K (skiparse attention) | 2e-5 | Full 3D -> Skiparse | +| 1 (T2I) | 256x256 | 1 | SAM, AnyText, Human Images | 1024 | 150K (full-attention) + 100K (skiparse attention) | 2e-5 | Full 3D -> Skiparse | | 2 (T2I&T2V) | maximumly 93×640×640 | 49 | SAM, Panda70M | 1024 | 200K | 2e-5 | Skiparse | | 3 (T2V) | 93x352x640 | 49 | filtered Panda70M, high-quality data | 1024 | 100K~200K | 1e-5 | Skiparse | @@ -414,8 +413,8 @@ Here we choose an example of training scripts (`train_t2i_stage1.sh`) and explai Here is the major command of the training script: ```shell NUM_FRAME=1 -WIDTH=320 -HEIGHT=320 +WIDTH=256 +HEIGHT=256 python opensora/train/train_t2v_diffusers.py \ --data "scripts/train_data/image_data_v1_2.txt" \ --num_frames ${NUM_FRAME} \ @@ -465,7 +464,7 @@ We evaluated the training performance on Ascend NPUs. All experiments are runnin | model name | cards | stage | batch size (global) | video size | Paramllelism |recompute |data sink | jit level| step time | train imgs/s | |:----------------|:----------- |:---------:|:-----:|:----------:|:----------:|:----------:|:----------:|:----------:|-------------------:|:----------:| -| OpenSoraT2V_v1_3-2B/122 | 8 | 1 | 8 | 1x320x320 | zero2 | TRUE | FALSE | O0 | 5.1s | | +| OpenSoraT2V_v1_3-2B/122 | 8 | 1 | 8 | 1x256x256 | zero2 | TRUE | FALSE | O0 | 5.1s | | | OpenSoraT2V_v1_3-2B/122 | 8 | 2 | 1 | up to 93x640x640 | zero2 + SP(sp_size=8) | TRUE | FALSE | O0 | | | | OpenSoraT2V_v1_3-2B/122 | 8 | 3 | 8 | 93x352x640 | zero2 | TRUE | FALSE | O0 | | | diff --git a/examples/opensora_pku/opensora/utils/sample_utils.py b/examples/opensora_pku/opensora/utils/sample_utils.py index 53f42d8309..ea1cce0826 100644 --- a/examples/opensora_pku/opensora/utils/sample_utils.py +++ b/examples/opensora_pku/opensora/utils/sample_utils.py @@ -305,7 +305,7 @@ def prepare_pipeline(args): ) + (f"\nsp_size: {args.sp_size}" if args.sp_size != 1 else ""), f"Num of samples: {len(args.text_prompt)}", - f"Num params: {num_params:,} (latte: {num_params_latte:,}, vae: {num_params_vae:,})", + f"Num params: {num_params:,} (dit: {num_params_latte:,}, vae: {num_params_vae:,})", f"Num trainable params: {num_params_trainable:,}", f"Transformer dtype: {dtype}", f"VAE dtype: {vae_dtype}", diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh index c7ecbb55c7..a1c8730ecb 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_ddp.sh @@ -14,7 +14,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --fps 18 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ - --enable_tiling \ --max_sequence_length 512 \ --sample_method EulerAncestralDiscrete \ --seed 1234 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh index d5e0e1a5c3..375b95e29a 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/sample_t2v_93x640_sp.sh @@ -14,7 +14,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --fps 18 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ - --enable_tiling \ --max_sequence_length 512 \ --sample_method EulerAncestralDiscrete \ --seed 1234 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh index f346460583..811efe91c1 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh @@ -1,7 +1,7 @@ -# Stage 1: 1x320x320 +# Stage 1: 1x256x256 NUM_FRAME=1 -WIDTH=320 -HEIGHT=320 +WIDTH=256 +HEIGHT=256 ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}/parallel_logs" \ opensora/train/train_t2v_diffusers.py \ @@ -39,7 +39,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --rescale_betas_zero_snr \ --use_ema False \ --ema_start_step 0 \ - --enable_tiling \ --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh index 463407cdee..cc0150734c 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh @@ -40,7 +40,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --rescale_betas_zero_snr \ --use_ema False \ --ema_start_step 0 \ - --enable_tiling \ --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh index a857e3be8a..c16fa68c06 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh @@ -39,7 +39,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --rescale_betas_zero_snr \ --use_ema False \ --ema_start_step 0 \ - --enable_tiling \ --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x320x320.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i.sh similarity index 87% rename from examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x320x320.sh rename to examples/opensora_pku/scripts/text_condition/single-device/sample_t2i.sh index f21ac031ee..b0dbf4fe18 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i_1x320x320.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2i.sh @@ -3,11 +3,11 @@ export DEVICE_ID=0 python opensora/sample/sample.py \ - --model_path LanguageBind/Open-Sora-Plan-v1.3.0/1x320x320 \ + --model_path LanguageBind/Open-Sora-Plan-v1.3.0/1x256x256 \ --version v1_3 \ --num_frames 1 \ - --height 320 \ - --width 320 \ + --height 256 \ + --width 256 \ --text_encoder_name_1 google/mt5-xxl \ --text_prompt examples/prompt_list_human_images.txt \ --ae WFVAEModel_D8_4x8x8 \ @@ -16,7 +16,6 @@ python opensora/sample/sample.py \ --fps 18 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ - --enable_tiling \ --max_sequence_length 512 \ --sample_method EulerAncestralDiscrete \ --seed 1234 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh index e55d833bb9..4aee7dc258 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640.sh @@ -16,7 +16,6 @@ python opensora/sample/sample.py \ --fps 18 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ - --enable_tiling \ --max_sequence_length 512 \ --sample_method EulerAncestralDiscrete \ --seed 1234 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_2texenc.sh b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_2texenc.sh index 54b3db8e30..db6634d786 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_2texenc.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/sample_t2v_93x640_2texenc.sh @@ -17,7 +17,6 @@ python opensora/sample/sample.py \ --fps 18 \ --guidance_scale 7.5 \ --num_sampling_steps 100 \ - --enable_tiling \ --max_sequence_length 512 \ --sample_method EulerAncestralDiscrete \ --num_samples_per_prompt 1 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh index 9d0d38819b..7c3cb01e21 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh @@ -1,7 +1,7 @@ -# Stage 1: 1x320x320 +# Stage 1: 1x256x256 NUM_FRAME=1 -WIDTH=320 -HEIGHT=320 +WIDTH=256 +HEIGHT=256 python opensora/train/train_t2v_diffusers.py \ --model OpenSoraT2V_v1_3-2B/122 \ --text_encoder_name_1 google/mt5-xxl \ @@ -37,7 +37,6 @@ python opensora/train/train_t2v_diffusers.py \ --rescale_betas_zero_snr \ --use_ema False \ --ema_start_step 0 \ - --enable_tiling \ --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh index d22465b2aa..30c901ca9d 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh @@ -38,7 +38,6 @@ python opensora/train/train_t2v_diffusers.py \ --rescale_betas_zero_snr \ --use_ema False \ --ema_start_step 0 \ - --enable_tiling \ --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh index 7107bd6a8e..fa76ca18c1 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh @@ -37,7 +37,6 @@ python opensora/train/train_t2v_diffusers.py \ --rescale_betas_zero_snr \ --use_ema False \ --ema_start_step 0 \ - --enable_tiling \ --clip_grad True \ --max_grad_norm 1.0 \ --noise_offset 0.02 \ From 8d01c3ddea609c62f3b99fa3796f637553874ba8 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 17 Dec 2024 12:00:48 +0800 Subject: [PATCH 109/133] fix sp inference error --- .../opensora/sample/pipeline_opensora.py | 57 +++++-------------- 1 file changed, 14 insertions(+), 43 deletions(-) diff --git a/examples/opensora_pku/opensora/sample/pipeline_opensora.py b/examples/opensora_pku/opensora/sample/pipeline_opensora.py index 1fab66343b..162bbf59f7 100644 --- a/examples/opensora_pku/opensora/sample/pipeline_opensora.py +++ b/examples/opensora_pku/opensora/sample/pipeline_opensora.py @@ -447,25 +447,14 @@ def prepare_latents( latents = latents * self.scheduler.init_noise_sigma return latents - def prepare_parallel_latent(self, video_states): + def prepare_parallel_input(self, input, sp_dim): sp_size = hccl_info.world_size index = hccl_info.rank % sp_size - padding_needed = (sp_size - video_states.shape[2] % sp_size) % sp_size - temp_attention_mask = None - if padding_needed > 0: - logger.debug("Doing video padding") - # B, C, T, H, W -> B, C, T', H, W - video_states = mint.nn.functional.pad( - video_states, (0, 0, 0, 0, 0, padding_needed), mode="constant", value=0 - ) - - b, _, f, h, w = video_states.shape - temp_attention_mask = mint.ones((b, f), ms.int32) - temp_attention_mask[:, -padding_needed:] = 0 - - assert video_states.shape[2] % sp_size == 0 - video_states = ops.chunk(video_states, sp_size, 2)[index] - return video_states, temp_attention_mask + assert ( + input.shape[sp_dim] % sp_size == 0 + ), f"Expect the parallel input length at dim={sp_dim} is divisble by the sp_size={sp_size}, but got {input.shape[sp_dim]}" + input = ops.chunk(input, sp_size, sp_dim)[index] + return input @property def guidance_scale(self): @@ -653,21 +642,10 @@ def __call__( # prompt_embeds = prompt_embeds.reshape(b, n, x, h).contiguous() # rank = hccl_info.rank # prompt_embeds = prompt_embeds[:, rank, :, :] - - latents, temp_attention_mask = self.prepare_parallel_latent(latents) - temp_attention_mask = ( - mint.cat([temp_attention_mask] * 2) - if (self.do_classifier_free_guidance and temp_attention_mask is not None) - else temp_attention_mask - ) # b (n x) h -> b n x h - prompt_embeds = prompt_embeds.reshape( - prompt_embeds.shape[0], world_size, prompt_embeds.shape[1] // world_size, -1 - ).contiguous() - index = hccl_info.rank % world_size - prompt_embeds = prompt_embeds[:, index, :, :] - else: - temp_attention_mask = None + prompt_embeds = self.prepare_parallel_input(prompt_embeds, sp_dim=1) + if prompt_embeds_2 is not None: + prompt_embeds_2 = self.prepare_parallel_input(prompt_embeds_2, sp_dim=1) # ==================make sp===================================== # 8. Denoising loop @@ -696,21 +674,14 @@ def __call__( if prompt_attention_mask.ndim == 2: prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l if prompt_embeds_2 is not None and prompt_embeds_2.ndim == 2: - prompt_embeds = prompt_embeds.unsqueeze(1) # b d -> b 1 d #OFFICIAL VER. DONT KNOW WHY + # prompt_embeds = prompt_embeds.unsqueeze(1) # b d -> b 1 d #OFFICIAL VER. DONT KNOW WHY # prompt_embeds_2 = prompt_embeds_2.unsqueeze(1) # + raise NotImplementedError - attention_mask = ops.ones_like(latent_model_input)[:, 0] - if temp_attention_mask is not None: - # temp_attention_mask shape (bs, t), 1 means to keep, 0 means to discard - # TODO: mask temporal padded tokens - attention_mask = ( - attention_mask.to(ms.int32) * temp_attention_mask[:, :, None, None].to(ms.int32) - ).to(ms.bool_) - # ==================prepare my shape===================================== - - # ==================make sp===================================== + attention_mask = ops.ones_like(latent_model_input)[:, 0].to(ms.int32) if get_sequence_parallel_state(): attention_mask = attention_mask.repeat(world_size, axis=1) + attention_mask = attention_mask.to(ms.bool_) # ==================make sp===================================== noise_pred = ops.stop_gradient( @@ -765,7 +736,7 @@ def __call__( # full_shape = [latents_shape[0] * world_size] + latents_shape[1:] # # b*sp c t//sp h w # all_latents = ops.zeros(full_shape, dtype=latents.dtype) all_latents = self.all_gather(latents) - latents_list = mint.chunk(all_latents, world_size, axis=0) + latents_list = mint.chunk(all_latents, world_size, 0) latents = mint.cat(latents_list, dim=2) # ==================make sp===================================== From 04b3087247642bdbac5ca73a43de880978ddbce3 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 17 Dec 2024 17:03:34 +0800 Subject: [PATCH 110/133] 1x256x256 exp script and performance update --- examples/opensora_pku/README.md | 4 ++-- .../scripts/text_condition/multi-devices/train_t2i_stage1.sh | 2 +- .../scripts/text_condition/single-device/train_t2i_stage1.sh | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/opensora_pku/README.md b/examples/opensora_pku/README.md index 48086674bb..d5547daa1f 100644 --- a/examples/opensora_pku/README.md +++ b/examples/opensora_pku/README.md @@ -464,8 +464,8 @@ We evaluated the training performance on Ascend NPUs. All experiments are runnin | model name | cards | stage | batch size (global) | video size | Paramllelism |recompute |data sink | jit level| step time | train imgs/s | |:----------------|:----------- |:---------:|:-----:|:----------:|:----------:|:----------:|:----------:|:----------:|-------------------:|:----------:| -| OpenSoraT2V_v1_3-2B/122 | 8 | 1 | 8 | 1x256x256 | zero2 | TRUE | FALSE | O0 | 5.1s | | -| OpenSoraT2V_v1_3-2B/122 | 8 | 2 | 1 | up to 93x640x640 | zero2 + SP(sp_size=8) | TRUE | FALSE | O0 | | | +| OpenSoraT2V_v1_3-2B/122 | 8 | 1 | 32 | 1x256x256 | zero2 | TRUE | FALSE | O0 | 4.37 | 7.32 | +| OpenSoraT2V_v1_3-2B/122 | 8 | 2 | 1 | up to 93x640x640 | zero2 + SP(sp_size=8) | TRUE | FALSE | O0 | | | | OpenSoraT2V_v1_3-2B/122 | 8 | 3 | 8 | 93x352x640 | zero2 | TRUE | FALSE | O0 | | | > SP: sequence parallelism. diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh index 811efe91c1..630f3c36d6 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh @@ -21,7 +21,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --interpolation_scale_h 1.0 \ --interpolation_scale_w 1.0 \ --gradient_checkpointing \ - --train_batch_size=1 \ + --train_batch_size=4 \ --dataloader_num_workers 8 \ --gradient_accumulation_steps=1 \ --max_train_steps=1000000 \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh index 7c3cb01e21..6a9921af1f 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh @@ -19,7 +19,7 @@ python opensora/train/train_t2v_diffusers.py \ --interpolation_scale_h 1.0 \ --interpolation_scale_w 1.0 \ --gradient_checkpointing \ - --train_batch_size=1 \ + --train_batch_size=4 \ --dataloader_num_workers 8 \ --gradient_accumulation_steps=1 \ --max_train_steps=1000000 \ From 618d4d261f15feccdf67e2b7b00680469a9cc016 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 17 Dec 2024 17:22:35 +0800 Subject: [PATCH 111/133] update other stages performance table --- examples/opensora_pku/README.md | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/opensora_pku/README.md b/examples/opensora_pku/README.md index d5547daa1f..636fb84d52 100644 --- a/examples/opensora_pku/README.md +++ b/examples/opensora_pku/README.md @@ -465,14 +465,11 @@ We evaluated the training performance on Ascend NPUs. All experiments are runnin | model name | cards | stage | batch size (global) | video size | Paramllelism |recompute |data sink | jit level| step time | train imgs/s | |:----------------|:----------- |:---------:|:-----:|:----------:|:----------:|:----------:|:----------:|:----------:|-------------------:|:----------:| | OpenSoraT2V_v1_3-2B/122 | 8 | 1 | 32 | 1x256x256 | zero2 | TRUE | FALSE | O0 | 4.37 | 7.32 | -| OpenSoraT2V_v1_3-2B/122 | 8 | 2 | 1 | up to 93x640x640 | zero2 + SP(sp_size=8) | TRUE | FALSE | O0 | | | -| OpenSoraT2V_v1_3-2B/122 | 8 | 3 | 8 | 93x352x640 | zero2 | TRUE | FALSE | O0 | | | +| OpenSoraT2V_v1_3-2B/122 | 8 | 2 | 1 | up to 93x640x640 | zero2 + SP(sp_size=8) | TRUE | FALSE | O0 | 22.4s* | 4.15 | +| OpenSoraT2V_v1_3-2B/122 | 8 | 3 | 8 | 93x352x640 | zero2 | TRUE | FALSE | O0 | 10.71 | 69.47 | > SP: sequence parallelism. -> -> Stage means the muti-stage training as illustrated above. - -> batch size: the local batch size for a single card. +> *: dynamic resolution using bucket sampler. The step time may vary across different batches due to the varied resolutions. ## 👍 Acknowledgement * [Latte](https://github.com/Vchitect/Latte): The **main codebase** we built upon and it is an wonderful video generated model. From 54efbe3ef7bcfc92566220eabb26d22ead8477ac Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 17 Dec 2024 18:47:40 +0800 Subject: [PATCH 112/133] rewrite test_data --- examples/opensora_pku/README.md | 2 +- examples/opensora_pku/tests/test_data.py | 143 +++++++++++++++++++++-- examples/opensora_pku/tests/test_data.sh | 17 +-- 3 files changed, 141 insertions(+), 21 deletions(-) diff --git a/examples/opensora_pku/README.md b/examples/opensora_pku/README.md index 636fb84d52..18885880c8 100644 --- a/examples/opensora_pku/README.md +++ b/examples/opensora_pku/README.md @@ -159,7 +159,7 @@ python opensora/sample/sample.py \ --model_path LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640 \ --version v1_3 \ --num_frames 93 \ - --height 640 \ + --height 352 \ --width 640 \ --text_encoder_name_1 google/mt5-xxl \ --text_prompt examples/sora.txt \ diff --git a/examples/opensora_pku/tests/test_data.py b/examples/opensora_pku/tests/test_data.py index 3a9b5222dd..601bbf8de8 100644 --- a/examples/opensora_pku/tests/test_data.py +++ b/examples/opensora_pku/tests/test_data.py @@ -47,7 +47,7 @@ def load_dataset_and_dataloader(args, device_num=1, rank_id=0): assert args.dataset == "t2v", "Support t2v dataset only." print_banner("Dataset Loading") # Setup data: - train_dataset = getdataset(args) + train_dataset = getdataset(args, dataset_file=args.data) sampler = ( LengthGroupedBatchSampler( args.train_batch_size, @@ -92,14 +92,69 @@ def load_dataset_and_dataloader(args, device_num=1, rank_id=0): def parse_t2v_train_args(parser): + # TODO: NEW in v1.3 , but may not use + # dataset & dataloader + parser.add_argument("--max_hxw", type=int, default=None) + parser.add_argument("--min_hxw", type=int, default=None) + parser.add_argument("--ood_img_ratio", type=float, default=0.0) + parser.add_argument("--group_data", action="store_true") + parser.add_argument("--hw_stride", type=int, default=32) + parser.add_argument("--force_resolution", action="store_true") + parser.add_argument("--trained_data_global_step", type=int, default=None) + parser.add_argument("--use_decord", action="store_true") + + # text encoder & vae & diffusion model + parser.add_argument("--vae_fp32", action="store_true") + parser.add_argument("--extra_save_mem", action="store_true") + parser.add_argument("--text_encoder_name_1", type=str, default="DeepFloyd/t5-v1_1-xxl") + parser.add_argument("--text_encoder_name_2", type=str, default=None) + parser.add_argument("--sparse1d", action="store_true") + parser.add_argument("--sparse_n", type=int, default=2) + parser.add_argument("--skip_connection", action="store_true") + parser.add_argument("--cogvideox_scheduler", action="store_true") + parser.add_argument("--v1_5_scheduler", action="store_true") + parser.add_argument("--rf_scheduler", action="store_true") + parser.add_argument( + "--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"] + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + + # diffusion setting + parser.add_argument("--offload_ema", action="store_true", help="Offload EMA model to CPU during training step.") + parser.add_argument("--foreach_ema", action="store_true", help="Use faster foreach implementation of EMAModel.") + parser.add_argument("--rescale_betas_zero_snr", action="store_true") + + # validation & logs + parser.add_argument("--enable_profiling", action="store_true") + parser.add_argument("--num_sampling_steps", type=int, default=20) + parser.add_argument("--guidance_scale", type=float, default=4.5) + parser.add_argument("--output_dir", default="outputs/", help="The directory where training results are saved.") parser.add_argument("--dataset", type=str, default="t2v") - parser.add_argument("--data", type=str, required=True) + parser.add_argument( + "--data", + type=str, + required=True, + help="The training dataset text file specifying the path of video folder, text embedding cache folder, and the annotation json file", + ) + parser.add_argument( + "--val_data", + type=str, + default=None, + help="The validation dataset text file, same format as the training dataset text file.", + ) parser.add_argument("--cache_dir", type=str, default="./cache_dir") - parser.add_argument("--model", type=str, choices=list(Diffusion_models.keys()), default="OpenSoraT2V-ROPE-L/122") - parser.add_argument("--cfg", type=float, default=0.1) - parser.add_argument("--ae", type=str, default="CausalVAEModel_4x8x8") - parser.add_argument("--ae_path", type=str, default="LanguageBind/Open-Sora-Plan-v1.1.0") parser.add_argument( "--filter_nonexistent", type=str2bool, @@ -113,12 +168,13 @@ def parse_t2v_train_args(parser): help="Whether to use T5 embedding cache. Must be provided in image/video_data.", ) parser.add_argument("--vae_latent_folder", default=None, type=str, help="root dir for the vae latent data") - + parser.add_argument("--model", type=str, choices=list(Diffusion_models.keys()), default="OpenSoraT2V_v1_3-2B/122") parser.add_argument("--interpolation_scale_h", type=float, default=1.0) parser.add_argument("--interpolation_scale_w", type=float, default=1.0) parser.add_argument("--interpolation_scale_t", type=float, default=1.0) parser.add_argument("--downsampler", type=str, default=None) - + parser.add_argument("--ae", type=str, default="WFVAEModel_D8_4x8x8") + parser.add_argument("--ae_path", type=str, default="LanguageBind/Open-Sora-Plan-v1.3.0") parser.add_argument("--sample_rate", type=int, default=1) parser.add_argument("--train_fps", type=int, default=24) parser.add_argument("--drop_short_ratio", type=float, default=1.0) @@ -128,16 +184,17 @@ def parse_t2v_train_args(parser): parser.add_argument("--max_width", type=int, default=240) parser.add_argument("--group_frame", action="store_true") parser.add_argument("--group_resolution", action="store_true") + parser.add_argument("--use_rope", action="store_true") + parser.add_argument("--pretrained", type=str, default=None) parser.add_argument("--tile_overlap_factor", type=float, default=0.25) parser.add_argument("--enable_tiling", action="store_true") parser.add_argument("--attention_mode", type=str, choices=["xformers", "math", "flash"], default="xformers") - parser.add_argument("--text_encoder_name", type=str, default="DeepFloyd/t5-v1_1-xxl") - parser.add_argument("--model_max_length", type=int, default=300) + # parser.add_argument("--text_encoder_name", type=str, default="DeepFloyd/t5-v1_1-xxl") + parser.add_argument("--model_max_length", type=int, default=512) parser.add_argument("--multi_scale", action="store_true") - # parser.add_argument("--enable_tracker", action="store_true") parser.add_argument("--use_image_num", type=int, default=0) parser.add_argument("--use_img_from_vid", action="store_true") parser.add_argument( @@ -146,10 +203,72 @@ def parse_t2v_train_args(parser): default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) - + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument("--cfg", type=float, default=0.1) + parser.add_argument( + "--num_no_recompute", + type=int, + default=0, + help="If use_recompute is True, `num_no_recompute` blocks will be removed from the recomputation list." + "This is a positive integer which can be tuned based on the memory usage.", + ) parser.add_argument("--dataloader_prefetch_size", type=int, default=None, help="minddata prefetch size setting") parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel") parser.add_argument("--train_sp_batch_size", type=int, default=1, help="Batch size for sequence parallel training") + parser.add_argument( + "--vae_keep_gn_fp32", + default=False, + type=str2bool, + help="whether keep GroupNorm in fp32. Defaults to False in inference, better to set to True when training vae", + ) + parser.add_argument( + "--vae_precision", + default="fp16", + type=str, + choices=["bf16", "fp16"], + help="what data type to use for vae. Default is `fp16`, which corresponds to ms.float16", + ) + parser.add_argument( + "--text_encoder_precision", + default="bf16", + type=str, + choices=["bf16", "fp16"], + help="what data type to use for T5 text encoder. Default is `bf16`, which corresponds to ms.bfloat16", + ) + parser.add_argument( + "--enable_parallel_fusion", default=True, type=str2bool, help="Whether to parallel fusion for AdamW" + ) + parser.add_argument("--jit_level", default="O0", help="Set jit level: # O0: KBK, O1:DVM, O2: GE") + parser.add_argument("--noise_offset", type=float, default=0.02, help="The scale of noise offset.") + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. More details here: \ + https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. \ + If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument("--ema_start_step", type=int, default=0) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) return parser diff --git a/examples/opensora_pku/tests/test_data.sh b/examples/opensora_pku/tests/test_data.sh index ef2f2e9ff5..72feb24a23 100644 --- a/examples/opensora_pku/tests/test_data.sh +++ b/examples/opensora_pku/tests/test_data.sh @@ -1,20 +1,21 @@ python tests/test_data.py \ - --model OpenSoraT2V-ROPE-L/122 \ - --text_encoder_name google/mt5-xxl \ + --model OpenSoraT2V_v1_3-2B/122 \ + --text_encoder_name_1 google/mt5-xxl \ --dataset t2v \ - --num_frames 29 \ + --num_frames 93 \ --data "scripts/train_data/merge_data_mixkit.txt" \ --cache_dir "./" \ - --ae CausalVAEModel_D4_4x8x8 \ - --ae_path "LanguageBind/Open-Sora-Plan-v1.2.0/vae" \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path "LanguageBind/Open-Sora-Plan-v1.3.0/vae" \ --sample_rate 1 \ - --max_height 480 \ + --max_height 352 \ --max_width 640 \ + --train_fps 16 \ + --force_resolution \ --interpolation_scale_t 1.0 \ --interpolation_scale_h 1.0 \ --interpolation_scale_w 1.0 \ - --attention_mode xformers \ --train_batch_size=8 \ --dataloader_num_workers 20 \ - --output_dir="t2i-image3d-1x480p/" \ + --output_dir="test_data/" \ --model_max_length 512 \ From a97395bec778ab5c1272425a481362ea22e299eb Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 17 Dec 2024 19:12:36 +0800 Subject: [PATCH 113/133] remove _backbone from dit ckpt --- .../opensora/models/diffusion/opensora/modeling_opensora.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index 573d8e7106..8b37c33161 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -298,6 +298,9 @@ def load_from_safetensors(cls, model, ckpt_path): model_file = ckpt_path pretrained_model_name_or_path = os.path.dirname(ckpt_path) state_dict = load_state_dict(model_file, variant=None) + state_dict = dict( + [k.replace("_backbone.", "") if "_backbone." in k else k, v] for k, v in state_dict.items() + ) # remove _backbone model._convert_deprecated_attention_blocks(state_dict) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( @@ -328,6 +331,9 @@ def load_from_ms_checkpoint(self, model, ckpt_path): new_pname = pname.replace(pre, "") sd[new_pname] = sd.pop(pname) + sd = dict( + [k.replace("_backbone.", "") if "_backbone." in k else k, v] for k, v in sd.items() + ) # remove _backbone m, u = ms.load_param_into_net(model, sd) print("net param not load: ", m, len(m)) print("ckpt param not load: ", u, len(u)) From 7af2562b3f43ab1a8db0a9750b0b1a1d1a02a3b2 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Wed, 18 Dec 2024 16:28:24 +0800 Subject: [PATCH 114/133] update demo --- examples/opensora_pku/README.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/examples/opensora_pku/README.md b/examples/opensora_pku/README.md index 18885880c8..d0bf16801b 100644 --- a/examples/opensora_pku/README.md +++ b/examples/opensora_pku/README.md @@ -25,6 +25,32 @@ Here we provide an efficient MindSpore version of [Open-Sora-Plan](https://githu | :---: | :---: | :---: | :---: | | 2.3.1 | 24.1RC2 |7.3.0.1.231| 8.0.RC2.beta1 | +## 🎥 Demo + +The following videos are generated based on MindSpore and Ascend 910*. + +Open-Sora-Plan v1.3.0 Demo + +93x352x640 Text-to-Video Generation. + +| 93x352x640 (5.8s) | +| --- | +| | +| A litter of golden retriever puppies playing in the snow... | + +| 93x352x640 (5.8s) | +| --- | +| | +| An extreme close-up of an gray-haired man with a beard in his 60s... | + +| 93x352x640 (5.8s) | +| --- | +| | +| Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach... | + +Videos are saved to `.gif` for display. + + ## 🔆 Features - 📍 **Open-Sora-Plan v1.3.0** with the following features @@ -469,6 +495,7 @@ We evaluated the training performance on Ascend NPUs. All experiments are runnin | OpenSoraT2V_v1_3-2B/122 | 8 | 3 | 8 | 93x352x640 | zero2 | TRUE | FALSE | O0 | 10.71 | 69.47 | > SP: sequence parallelism. + > *: dynamic resolution using bucket sampler. The step time may vary across different batches due to the varied resolutions. ## 👍 Acknowledgement From 0ffab7fd1e93969ae6e76ee5469d31aa67f4c0ed Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:58:41 +0800 Subject: [PATCH 115/133] fix sp inference error --- .../opensora/models/diffusion/opensora/modules.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py index 7d9bf6dc3c..5e75e62504 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py @@ -175,7 +175,7 @@ def __init__( self.alltoall_sbh_q = AllToAll_SBH(scatter_dim=1, gather_dim=0) self.alltoall_sbh_k = AllToAll_SBH(scatter_dim=1, gather_dim=0) self.alltoall_sbh_v = AllToAll_SBH(scatter_dim=1, gather_dim=0) - self.alltoall_sbh_out = AllToAll_SBH(scatter_dim=0, gather_dim=1) + self.alltoall_sbh_out = AllToAll_SBH(scatter_dim=1, gather_dim=0) else: self.sp_size = 1 self.alltoall_sbh_q = None @@ -362,11 +362,11 @@ def __call__( hidden_states = self._reverse_sparse_1d(hidden_states, total_frame, height, width, pad_len) hidden_states = hidden_states.swapaxes(0, 1) # SBH -> BSH - # [s, b, h // sp * d] -> [s // sp * b, h, d] -> [s // sp, b, h * d] if get_sequence_parallel_state(): - hidden_states = self.alltoall_sbh_out(hidden_states.reshape(-1, FA_head_num, head_dim)) - hidden_states = hidden_states.view(-1, batch_size, inner_dim) - + # [b, s * sp, h // sp, d] -> [h // sp, s * sp, b , d] + hidden_states = hidden_states.view(batch_size, -1, FA_head_num, head_dim).transpose(2, 1, 0, 3) + # [h // sp, s * sp, b , d] -> [h, s, b , d] -> [s, b, h, d] -> [s, b, h*d] + hidden_states = self.alltoall_sbh_out(hidden_states).transpose(1, 2, 0, 3).view(-1, batch_size, inner_dim) hidden_states = hidden_states.to(query.dtype) # linear proj From f09d6105fd8d5d5e9a6981cbfdd311a28532b22e Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 20 Dec 2024 10:04:16 +0800 Subject: [PATCH 116/133] make block input as SBH --- .../opensora/models/diffusion/common.py | 12 +-- .../diffusion/opensora/modeling_opensora.py | 36 +++----- .../models/diffusion/opensora/modules.py | 83 ++++++------------- .../diffusion/opensora/net_with_loss.py | 4 - 4 files changed, 39 insertions(+), 96 deletions(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/common.py b/examples/opensora_pku/opensora/models/diffusion/common.py index 799c49b8ea..5a708ee052 100644 --- a/examples/opensora_pku/opensora/models/diffusion/common.py +++ b/examples/opensora_pku/opensora/models/diffusion/common.py @@ -1,7 +1,6 @@ import itertools import numpy as np -from opensora.acceleration.parallel_states import get_sequence_parallel_state import mindspore as ms from mindspore import mint, nn, ops @@ -54,10 +53,7 @@ def __call__(self, b, t, h, w): z = list(range(t)) pos = list(itertools.product(z, y, x)) pos = ms.Tensor(pos) - if get_sequence_parallel_state(): - pos = pos.reshape(t * h * w, 3).swapaxes(0, 1).reshape(3, -1, 1).broadcast_to((3, -1, b)) - else: - pos = pos.reshape(t * h * w, 3).swapaxes(0, 1).reshape(3, 1, -1).broadcast_to((3, b, -1)) + pos = pos.reshape(t * h * w, 3).swapaxes(0, 1).reshape(3, -1, 1).broadcast_to((3, -1, b)) poses = (pos[0], pos[1], pos[2]) max_poses = (int(poses[0].max()), int(poses[1].max()), int(poses[2].max())) @@ -97,12 +93,6 @@ def rotate_half(x): def apply_rope1d(self, tokens, pos1d, cos, sin): assert pos1d.ndim == 2 - # if npu_config is None and not get_sequence_parallel_state(): - # # for (batch_size x nheads x ntokens x dim) - # cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] - # sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] - # else: - # for (batch_size x ntokens x nheads x dim) cos = cos[pos1d.to(ms.int32)][:, :, None, :] sin = sin[pos1d.to(ms.int32)][:, :, None, :] diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index 8b37c33161..7a0ba7bffd 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -3,7 +3,6 @@ import os from typing import Optional -from opensora.acceleration.parallel_states import get_sequence_parallel_state from opensora.models.diffusion.common import PatchEmbed2D from opensora.models.diffusion.opensora.modules import Attention, BasicTransformerBlock, LayerNorm from opensora.npu_config import npu_config @@ -399,28 +398,20 @@ def construct( hidden_states, encoder_hidden_states, timestep, batch_size, frame ) - if get_sequence_parallel_state(): - # To - # x (t*h*w b d) or (t//sp*h*w b d) - # cond_1 (l b d) or (l//sp b d) - # b s h -> s b h - hidden_states = hidden_states.swapaxes(0, 1).contiguous() - # b s h -> s b h - encoder_hidden_states = encoder_hidden_states.swapaxes(0, 1).contiguous() - timestep = timestep.view(batch_size, 6, -1).swapaxes(0, 1).contiguous() - - # if npu_config is None: - # if get_sequence_parallel_state(): - # head_num = self.config.num_attention_heads // hccl_info.world_size - # else: - # head_num = self.config.num_attention_heads - # else: - head_num = None + # To + # x (t*h*w b d) or (t//sp*h*w b d) + # cond_1 (l b d) or (l//sp b d) + # b s h -> s b h + hidden_states = hidden_states.swapaxes(0, 1).contiguous() + # b s h -> s b h + encoder_hidden_states = encoder_hidden_states.swapaxes(0, 1).contiguous() + timestep = timestep.view(batch_size, 6, -1).swapaxes(0, 1).contiguous() + sparse_mask = {} if self.sparse1d: for sparse_n in [1, self.sparse_n]: sparse_mask[sparse_n] = Attention.prepare_sparse_mask( - attention_mask, encoder_attention_mask, sparse_n, head_num + attention_mask, encoder_attention_mask, sparse_n, head_num=None ) # 2. Blocks @@ -444,10 +435,9 @@ def construct( width=width, ) # BSH - if get_sequence_parallel_state(): - # To (b, t*h*w, h) or (b, t//sp*h*w, h) - # s b h -> b s h - hidden_states = hidden_states.swapaxes(0, 1).contiguous() + # To (b, t*h*w, h) or (b, t//sp*h*w, h) + # s b h -> b s h + hidden_states = hidden_states.swapaxes(0, 1).contiguous() # 3. Output output = self._get_output_for_patched_inputs( diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py index 5e75e62504..422fa11d69 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py @@ -175,7 +175,7 @@ def __init__( self.alltoall_sbh_q = AllToAll_SBH(scatter_dim=1, gather_dim=0) self.alltoall_sbh_k = AllToAll_SBH(scatter_dim=1, gather_dim=0) self.alltoall_sbh_v = AllToAll_SBH(scatter_dim=1, gather_dim=0) - self.alltoall_sbh_out = AllToAll_SBH(scatter_dim=1, gather_dim=0) + self.alltoall_sbh_out = AllToAll_SBH(scatter_dim=0, gather_dim=1) else: self.sp_size = 1 self.alltoall_sbh_q = None @@ -268,15 +268,9 @@ def __call__( width: int = 16, ) -> ms.Tensor: # residual = hidden_states - - if get_sequence_parallel_state(): - sequence_length, batch_size, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - else: - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) # BSH + sequence_length, batch_size, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) # attention_mask shape if attention_mask.ndim == 3: @@ -305,40 +299,21 @@ def __call__( key = self.alltoall_sbh_k(key.view(-1, attn.heads, head_dim)) value = self.alltoall_sbh_v(value.view(-1, attn.heads, head_dim)) - # print(f'batch: {batch_size}, FA_head_num: {FA_head_num}, head_dim: {head_dim}, total_frame:{total_frame}') - query = query.view(-1, batch_size, FA_head_num, head_dim) # BUG? TODO: to test - key = key.view(-1, batch_size, FA_head_num, head_dim) # BUG ? - - # print(f'q {query.shape}, k {key.shape}, v {value.shape}') - if not self.is_cross_attn: - # require the shape of (ntokens x batch_size x nheads x dim) - pos_thw = self.position_getter(batch_size, t=total_frame, h=height, w=width) - # print(f'pos_thw {pos_thw}') - query = self.rope(query, pos_thw) - key = self.rope(key, pos_thw) - - query = query.view(-1, batch_size, FA_head_num * head_dim) - key = key.view(-1, batch_size, FA_head_num * head_dim) - value = value.view(-1, batch_size, FA_head_num * head_dim) - else: - # print(f'batch: {batch_size}, FA_head_num: {FA_head_num}, head_dim: {head_dim}, total_frame:{total_frame}') - query = query.view(batch_size, -1, FA_head_num, head_dim) - key = key.view(batch_size, -1, FA_head_num, head_dim) - # (batch_size x ntokens x nheads x dim) - - # print(f'q {query.shape}, k {key.shape}, v {value.shape}') - if not self.is_cross_attn: - # require the shape of (batch_size x ntokens x nheads x dim) - pos_thw = self.position_getter(batch_size, t=total_frame, h=height, w=width) - # print(f'pos_thw {pos_thw}') - query = self.rope(query, pos_thw) - key = self.rope(key, pos_thw) + # print(f'batch: {batch_size}, FA_head_num: {FA_head_num}, head_dim: {head_dim}, total_frame:{total_frame}') + query = query.view(-1, batch_size, FA_head_num, head_dim) + key = key.view(-1, batch_size, FA_head_num, head_dim) - query = query.view(batch_size, -1, FA_head_num * head_dim).swapaxes(0, 1) - key = key.view(batch_size, -1, FA_head_num * head_dim).swapaxes(0, 1) - value = value.swapaxes(0, 1) + # print(f'q {query.shape}, k {key.shape}, v {value.shape}') + if not self.is_cross_attn: + # require the shape of (ntokens x batch_size x nheads x dim) or (batch_size x ntokens x nheads x dim) + pos_thw = self.position_getter(batch_size, t=total_frame, h=height, w=width) + # print(f'pos_thw {pos_thw}') + query = self.rope(query, pos_thw) + key = self.rope(key, pos_thw) - # print(f'q {query.shape}, k {key.shape}, v {value.shape}') #(SBH) + query = query.view(-1, batch_size, FA_head_num * head_dim) + key = key.view(-1, batch_size, FA_head_num * head_dim) + value = value.view(-1, batch_size, FA_head_num * head_dim) if self.sparse1d: query, pad_len = self._sparse_1d(query, total_frame, height, width) @@ -356,17 +331,15 @@ def __call__( hidden_states = npu_config.run_attention( query, key, value, attention_mask, input_layout="BSH", head_dim=head_dim, head_num=FA_head_num ) + hidden_states = hidden_states.swapaxes(0, 1) # BSH -> SBH if self.sparse1d: - hidden_states = hidden_states.swapaxes(0, 1) # BSH -> SBH hidden_states = self._reverse_sparse_1d(hidden_states, total_frame, height, width, pad_len) - hidden_states = hidden_states.swapaxes(0, 1) # SBH -> BSH + # [s, b, h // sp * d] -> [s // sp * b, h, d] -> [s // sp, b, h * d] if get_sequence_parallel_state(): - # [b, s * sp, h // sp, d] -> [h // sp, s * sp, b , d] - hidden_states = hidden_states.view(batch_size, -1, FA_head_num, head_dim).transpose(2, 1, 0, 3) - # [h // sp, s * sp, b , d] -> [h, s, b , d] -> [s, b, h, d] -> [s, b, h*d] - hidden_states = self.alltoall_sbh_out(hidden_states).transpose(1, 2, 0, 3).view(-1, batch_size, inner_dim) + hidden_states = self.alltoall_sbh_out(hidden_states.reshape(-1, FA_head_num, head_dim)) + hidden_states = hidden_states.view(-1, batch_size, inner_dim) hidden_states = hidden_states.to(query.dtype) # linear proj @@ -478,16 +451,10 @@ def construct( width: int, ) -> ms.Tensor: # 0. Self-Attention - if get_sequence_parallel_state(): - batch_size = hidden_states.shape[1] - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk( - self.scale_shift_table[:, None] + timestep.reshape(6, batch_size, -1), 6, dim=0 - ) - else: - batch_size = hidden_states.shape[0] - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk( - self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1), 6, dim=1 - ) + batch_size = hidden_states.shape[1] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mint.chunk( + self.scale_shift_table[:, None] + timestep.reshape(6, batch_size, -1), 6, dim=0 + ) norm_hidden_states = self.norm1(hidden_states) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py b/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py index 191ee64565..b44c2959b0 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/net_with_loss.py @@ -120,10 +120,6 @@ def __init__( self.text_encoder = text_encoder self.use_image_num = use_image_num - - # FIXME: bug when sp_size=2 - # self.broadcast_t = None if not get_sequence_parallel_state() \ - # else ops.Broadcast(root_rank=int(hccl_info.group_id * hccl_info.world_size), group=hccl_info.group) self.reduce_t = None if not get_sequence_parallel_state() else ops.AllReduce(group=hccl_info.group) self.sp_size = 1 if not get_sequence_parallel_state() else hccl_info.world_size self.all_gather = None if not get_sequence_parallel_state() else ops.AllGather(group=hccl_info.group) From 9c72d71b2278d5a08d89988c008fe96c6e8276e3 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 20 Dec 2024 11:30:50 +0800 Subject: [PATCH 117/133] update new dataset v0.3.0 --- .../opensora_pku/opensora/dataset/__init__.py | 46 +- .../opensora_pku/opensora/dataset/loader.py | 36 +- .../opensora/dataset/t2v_datasets.py | 796 ++++++++++++------ .../opensora/dataset/transform.py | 230 ++++- .../opensora/train/train_t2v_diffusers.py | 66 +- .../opensora/utils/dataset_utils.py | 380 ++++++--- .../scripts/train_data/video_data_v1_2.txt | 2 +- examples/opensora_pku/tests/test_data.py | 48 +- examples/opensora_pku/tests/test_data.sh | 11 +- 9 files changed, 1103 insertions(+), 512 deletions(-) diff --git a/examples/opensora_pku/opensora/dataset/__init__.py b/examples/opensora_pku/opensora/dataset/__init__.py index d596a1d9a9..8c38d20fa6 100644 --- a/examples/opensora_pku/opensora/dataset/__init__.py +++ b/examples/opensora_pku/opensora/dataset/__init__.py @@ -6,7 +6,7 @@ from transformers import AutoTokenizer from .t2v_datasets import T2V_dataset -from .transform import TemporalRandomCrop, center_crop_th_tw, spatial_stride_crop_video, maxhxw_resize +from .transform import TemporalRandomCrop, center_crop_th_tw, maxhxw_resize, spatial_stride_crop_video def getdataset(args, dataset_file): @@ -18,16 +18,11 @@ def norm_func_albumentation(image, **kwargs): mapping = {"bilinear": cv2.INTER_LINEAR, "bicubic": cv2.INTER_CUBIC} targets = {"image{}".format(i): "image" for i in range(args.num_frames)} - resize_topcrop = [ - Lambda( - name="crop_topcrop", - image=partial(center_crop_th_tw, th=args.max_height, tw=args.max_width, top_crop=True), - p=1.0, - ), - Resize(args.max_height, args.max_width, interpolation=mapping["bilinear"]), - ] + if args.force_resolution: - assert (args.max_height is not None) and (args.max_width is not None), "set max_height and max_width for fixed resolution" + assert (args.max_height is not None) and ( + args.max_width is not None + ), "set max_height and max_width for fixed resolution" resize = [ Lambda( name="crop_centercrop", @@ -36,7 +31,7 @@ def norm_func_albumentation(image, **kwargs): ), Resize(args.max_height, args.max_width, interpolation=mapping["bilinear"]), ] - else: # dynamic resolution + else: # dynamic resolution assert args.max_hxw is not None, "set max_hxw for dynamic resolution" resize = [ Lambda( @@ -46,7 +41,7 @@ def norm_func_albumentation(image, **kwargs): ), Lambda( name="spatial_stride_crop", - image=partial(spatial_stride_crop_video, stride=args.hw_stride), # default stride=32 + image=partial(spatial_stride_crop_video, stride=args.hw_stride), # default stride=32 p=1.0, ), ] @@ -55,35 +50,20 @@ def norm_func_albumentation(image, **kwargs): [*resize, ToFloat(255.0), Lambda(name="ae_norm", image=norm_func_albumentation, p=1.0)], additional_targets=targets, ) - transform_topcrop = Compose( - [*resize_topcrop, ToFloat(255.0), Lambda(name="ae_norm", image=norm_func_albumentation, p=1.0)], - additional_targets=targets, - ) - tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name_1, cache_dir=args.cache_dir) + tokenizer_1 = AutoTokenizer.from_pretrained(args.text_encoder_name_1, cache_dir=args.cache_dir) + tokenizer_2 = None if args.text_encoder_name_2 is not None: tokenizer_2 = AutoTokenizer.from_pretrained(args.text_encoder_name_2, cache_dir=args.cache_dir) if args.dataset == "t2v": return T2V_dataset( - dataset_file, - num_frames=args.num_frames, - train_fps=args.train_fps, - use_image_num=args.use_image_num, - use_img_from_vid=args.use_img_from_vid, - model_max_length=args.model_max_length, - cfg=args.cfg, - speed_factor=args.speed_factor, - max_height=args.max_height, - max_width=args.max_width, - drop_short_ratio=args.drop_short_ratio, - dataloader_num_workers=args.dataloader_num_workers, - text_encoder_name=args.text_encoder_name_1, # TODO: update with 2nd text encoder - return_text_emb=args.text_embed_cache, + args, transform=transform, temporal_sample=temporal_sample, - tokenizer=tokenizer, - transform_topcrop=transform_topcrop, + tokenizer_1=tokenizer_1, + tokenizer_2=tokenizer_2, + return_text_emb=args.text_embed_cache, ) elif args.dataset == "inpaint" or args.dataset == "i2v": raise NotImplementedError diff --git a/examples/opensora_pku/opensora/dataset/loader.py b/examples/opensora_pku/opensora/dataset/loader.py index 832dbd55fa..d53372a0dc 100644 --- a/examples/opensora_pku/opensora/dataset/loader.py +++ b/examples/opensora_pku/opensora/dataset/loader.py @@ -26,6 +26,7 @@ def create_dataloader( enable_modelarts=False, collate_fn=None, sampler=None, + batch_sampler=None, ): datalen = len(dataset) @@ -46,6 +47,7 @@ def create_dataloader( shuffle=shuffle, drop_last=drop_last, sampler=sampler, + batch_sampler=batch_sampler, ) dl = GeneratorDataset( loader, @@ -62,13 +64,24 @@ def create_dataloader( def build_dataloader( - dataset, datalens, collate_fn, batch_size, device_num, rank_id=0, sampler=None, shuffle=True, drop_last=True + dataset, + datalens, + collate_fn, + batch_size, + device_num, + rank_id=0, + sampler=None, + batch_sampler=None, + shuffle=True, + drop_last=True, ): - if sampler is None: - sampler = BatchSampler(datalens, batch_size=batch_size, device_num=device_num, shuffle=shuffle) + if batch_sampler is None: + batch_sampler = BatchSampler(datalens, batch_size=batch_size, device_num=device_num, shuffle=shuffle) loader = DataLoader( dataset, - batch_sampler=sampler, + batch_size=batch_size, + sampler=sampler, + batch_sampler=batch_sampler, collate_fn=collate_fn, device_num=device_num, drop_last=drop_last, @@ -107,14 +120,25 @@ def __len__(self): class DataLoader: """DataLoader""" - def __init__(self, dataset, batch_sampler, collate_fn, device_num=1, drop_last=True, rank_id=0): + def __init__( + self, + dataset, + batch_size, + sampler=None, + batch_sampler=None, + collate_fn=None, + device_num=1, + drop_last=True, + rank_id=0, + ): self.dataset = dataset + self.sampler = sampler self.batch_sampler = batch_sampler self.collat_fn = collate_fn self.device_num = device_num self.rank_id = rank_id self.drop_last = drop_last - self.batch_size = len(next(iter(self.batch_sampler))) + self.batch_size = batch_size def __iter__(self): self.step_index = 0 diff --git a/examples/opensora_pku/opensora/dataset/t2v_datasets.py b/examples/opensora_pku/opensora/dataset/t2v_datasets.py index 2909d361b0..013d4be5b1 100644 --- a/examples/opensora_pku/opensora/dataset/t2v_datasets.py +++ b/examples/opensora_pku/opensora/dataset/t2v_datasets.py @@ -3,15 +3,27 @@ import logging import math import os +import pickle import random +import time from collections import Counter +from concurrent.futures import ThreadPoolExecutor from os.path import join as opj from pathlib import Path +import cv2 +import decord import numpy as np -from opensora.utils.dataset_utils import DecordInit +from opensora.dataset.transform import ( + add_aesthetic_notice_image, + add_aesthetic_notice_video, + calculate_statistics, + get_params, + maxhwresize, +) from opensora.utils.utils import text_preprocessing from PIL import Image +from tqdm import tqdm logger = logging.getLogger(__name__) @@ -84,10 +96,40 @@ def get_item(self, work_info): dataset_prog = DataSetProg() -def find_closest_y(x, vae_stride_t=4, model_ds_t=4): - if x < 29: +class DecordDecoder(object): + def __init__(self, url, num_threads=1): + self.num_threads = num_threads + self.ctx = decord.cpu(0) + self.reader = decord.VideoReader(url, ctx=self.ctx, num_threads=self.num_threads) + + def get_avg_fps(self): + return self.reader.get_avg_fps() if self.reader.get_avg_fps() > 0 else 30.0 + + def get_num_frames(self): + return len(self.reader) + + def get_height(self): + return self.reader[0].shape[0] if self.get_num_frames() > 0 else 0 + + def get_width(self): + return self.reader[0].shape[1] if self.get_num_frames() > 0 else 0 + + # output shape [T, H, W, C] + def get_batch(self, frame_indices): + try: + # frame_indices[0] = 1000 + video_data = self.reader.get_batch(frame_indices).asnumpy() + return video_data + except Exception as e: + print("get_batch execption:", e) + return None + + +def find_closest_y(x, vae_stride_t=4, model_ds_t=1): + min_num_frames = 29 + if x < min_num_frames: return -1 - for y in range(x, 12, -1): + for y in range(x, min_num_frames - 1, -1): if (y - 1) % vae_stride_t == 0 and ((y - 1) // vae_stride_t + 1) % model_ds_t == 0: return y return -1 @@ -102,79 +144,334 @@ def filter_resolution(h, w, max_h_div_w_ratio=17 / 16, min_h_div_w_ratio=8 / 16) class T2V_dataset: def __init__( self, - data, - num_frames: int = 29, - train_fps: int = 24, - use_image_num: int = 0, - use_img_from_vid: bool = False, - model_max_length: int = 512, - cfg: float = 0.1, - speed_factor: float = 1.0, - max_height: int = 480, - max_width: int = 640, - drop_short_ratio: float = 1.0, - dataloader_num_workers: int = 10, - text_encoder_name: str = "google/mt5-xxl", - transform=None, - temporal_sample=None, - tokenizer=None, - transform_topcrop=None, + args, + transform, + temporal_sample, + tokenizer_1, + tokenizer_2, filter_nonexistent=True, return_text_emb=False, ): - self.data = data - self.num_frames = num_frames - self.train_fps = train_fps - self.use_image_num = use_image_num - self.use_img_from_vid = use_img_from_vid + self.data = args.data + self.num_frames = args.num_frames + self.train_fps = args.train_fps self.transform = transform - self.transform_topcrop = transform_topcrop self.temporal_sample = temporal_sample - self.tokenizer = tokenizer - self.model_max_length = model_max_length - self.cfg = cfg - self.speed_factor = speed_factor - self.max_height = max_height - self.max_width = max_width - self.drop_short_ratio = drop_short_ratio + self.tokenizer_1 = tokenizer_1 + self.tokenizer_2 = tokenizer_2 + self.model_max_length = args.model_max_length + self.cfg = args.cfg + self.speed_factor = args.speed_factor + self.max_height = args.max_height + self.max_width = args.max_width + self.drop_short_ratio = args.drop_short_ratio + self.hw_stride = args.hw_stride + self.force_resolution = args.force_resolution + self.max_hxw = args.max_hxw + self.min_hxw = args.min_hxw + self.sp_size = args.sp_size assert self.speed_factor >= 1 - self.v_decoder = DecordInit() + self.video_reader = "decord" if args.use_decord else "opencv" + self.ae_stride_t = args.ae_stride_t + self.total_batch_size = args.total_batch_size + self.seed = args.seed + self.generator = np.random.default_rng(self.seed) + self.hw_aspect_thr = 2.0 # just a threshold + self.too_long_factor = 10.0 # set this threshold larger for longer video datasets self.filter_nonexistent = filter_nonexistent self.return_text_emb = return_text_emb if self.return_text_emb and self.cfg > 0: logger.warning(f"random text drop ratio {self.cfg} will be ignored when text embeddings are cached.") self.duration_threshold = 100.0 - self.support_Chinese = True - if not ("mt5" in text_encoder_name): - self.support_Chinese = False + self.support_Chinese = False + if "mt5" in args.text_encoder_name_1: + self.support_Chinese = True + if args.text_encoder_name_2 is not None and "mt5" in args.text_encoder_name_2: + self.support_Chinese = True + s = time.time() + cap_list, self.sample_size, self.shape_idx_dict = self.define_frame_index(self.data) + e = time.time() + print(f"Build data time: {e-s}") + self.lengths = self.sample_size - cap_list = self.get_cap_list() - if self.filter_nonexistent: - cap_list = self.filter_nonexistent_files(cap_list) + n_elements = len(cap_list) + dataset_prog.set_cap_list(args.dataloader_num_workers, cap_list, n_elements) + print(f"Data length: {len(dataset_prog.cap_list)}") + self.executor = ThreadPoolExecutor(max_workers=1) + self.timeout = 60 - assert len(cap_list) > 0 - cap_list, self.sample_num_frames = self.define_frame_index(cap_list) - self.lengths = self.sample_num_frames + def define_frame_index(self, data): + shape_idx_dict = {} + new_cap_list = [] + sample_size = [] + aesthetic_score = [] + cnt_vid = 0 + cnt_img = 0 + cnt_too_long = 0 + cnt_too_short = 0 + cnt_no_cap = 0 + cnt_no_resolution = 0 + cnt_no_aesthetic = 0 + cnt_img_res_mismatch_stride = 0 + cnt_vid_res_mismatch_stride = 0 + cnt_img_aspect_mismatch = 0 + cnt_vid_aspect_mismatch = 0 + cnt_img_res_too_small = 0 + cnt_vid_res_too_small = 0 + cnt_vid_after_filter = 0 + cnt_img_after_filter = 0 + cnt_no_existent = 0 + cnt = 0 - n_elements = len(cap_list) - dataset_prog.set_cap_list(dataloader_num_workers, cap_list, n_elements) - - print(f"video length: {len(dataset_prog.cap_list)}", flush=True) - - def filter_nonexistent_files(self, cap_list): - indexes_to_remove = [] - for i, item in enumerate(cap_list): - path = item["path"] - if not os.path.exists(path): - second_path = path.replace("_resize1080p.mp4", ".mp4") - if os.path.exists(second_path): - cap_list[i]["path"] = second_path + with open(data, "r") as f: + folder_anno = [i.strip().split(",") for i in f.readlines() if len(i.strip()) > 0] + assert len(folder_anno) > 0, "input dataset file cannot be empty!" + for input_dataset in tqdm(folder_anno): + text_embed_folder_1, text_embed_folder_2 = None, None + if len(input_dataset) == 2: + assert not self.return_text_emb, "Train without text embedding cache!" + elif len(input_dataset) == 3: + text_embed_folder_1 = input_dataset[1] + sub_root, anno = input_dataset[0], input_dataset[-1] + elif len(input_dataset) == 4: + text_embed_folder_1 = input_dataset[1] + text_embed_folder_2 = input_dataset[2] + sub_root, anno = input_dataset[0], input_dataset[-1] + else: + raise ValueError("Not supported input dataset file!") + + print(f"Building {anno}...") + if anno.endswith(".json"): + with open(anno, "r") as f: + sub_list = json.load(f) + elif anno.endswith(".pkl"): + with open(anno, "rb") as f: + sub_list = pickle.load(f) + for index, i in enumerate(tqdm(sub_list)): + cnt += 1 + path = os.path.join(sub_root, i["path"]) + if self.filter_nonexistent: + if not os.path.exists(path): + cnt_no_existent += 1 + continue + + if self.return_text_emb: + text_embeds_paths = self.get_text_embed_file_path(i) + if text_embed_folder_1 is not None: + i["text_embed_path_1"] = [opj(text_embed_folder_1, tp) for tp in text_embeds_paths] + if any([not os.path.exists(p) for p in i["text_embed_path_1"]]): + cnt_no_existent += 1 + continue + if text_embed_folder_2 is not None: + i["text_embed_path_2"] = [opj(text_embed_folder_2, tp) for tp in text_embeds_paths] + if any([not os.path.exists(p) for p in i["text_embed_path_2"]]): + cnt_no_existent += 1 + continue + + if path.endswith(".mp4"): + cnt_vid += 1 + elif path.endswith(".jpg"): + cnt_img += 1 + + # ======no aesthetic===== + if i.get("aesthetic", None) is None or i.get("aes", None) is None: + cnt_no_aesthetic += 1 + else: + aesthetic_score.append(i.get("aesthetic", None) or i.get("aes", None)) + + # ======no caption===== + cap = i.get("cap", None) + if cap is None: + cnt_no_cap += 1 + continue + + # ======resolution mismatch===== + i["path"] = path + assert ( + "resolution" in i + ), "Expect that each element in the provided datset should have a item named `resolution`" + if i.get("resolution", None) is None: + cnt_no_resolution += 1 + continue else: - indexes_to_remove.append(i) - cap_list = [item for i, item in enumerate(cap_list) if i not in indexes_to_remove] - logger.info(f"Nonexistent files: {len(indexes_to_remove)}") - return cap_list + assert ( + "height" in i["resolution"] and "width" in i["resolution"] + ), "Expect that each element has `resolution: \\{'height': int, 'width': int,\\}`" + if i["resolution"].get("height", None) is None or i["resolution"].get("width", None) is None: + cnt_no_resolution += 1 + continue + else: + height, width = i["resolution"]["height"], i["resolution"]["width"] + if not self.force_resolution: + if height <= 0 or width <= 0: + cnt_no_resolution += 1 + continue + + tr_h, tr_w = maxhwresize(height, width, self.max_hxw) + _, _, sample_h, sample_w = get_params(tr_h, tr_w, self.hw_stride) + + if sample_h <= 0 or sample_w <= 0: + if path.endswith(".mp4"): + cnt_vid_res_mismatch_stride += 1 + elif path.endswith(".jpg"): + cnt_img_res_mismatch_stride += 1 + continue + + # filter min_hxw + if sample_h * sample_w < self.min_hxw: + if path.endswith(".mp4"): + cnt_vid_res_too_small += 1 + elif path.endswith(".jpg"): + cnt_img_res_too_small += 1 + continue + + # filter aspect + is_pick = filter_resolution( + sample_h, + sample_w, + max_h_div_w_ratio=self.hw_aspect_thr, + min_h_div_w_ratio=1 / self.hw_aspect_thr, + ) + if not is_pick: + if path.endswith(".mp4"): + cnt_vid_aspect_mismatch += 1 + elif path.endswith(".jpg"): + cnt_img_aspect_mismatch += 1 + continue + + i["resolution"].update(dict(sample_height=sample_h, sample_width=sample_w)) + + else: + aspect = self.max_height / self.max_width + is_pick = filter_resolution( + height, + width, + max_h_div_w_ratio=self.hw_aspect_thr * aspect, + min_h_div_w_ratio=1 / self.hw_aspect_thr * aspect, + ) + if not is_pick: + if path.endswith(".mp4"): + cnt_vid_aspect_mismatch += 1 + elif path.endswith(".jpg"): + cnt_img_aspect_mismatch += 1 + continue + sample_h, sample_w = self.max_height, self.max_width + + i["resolution"].update(dict(sample_height=sample_h, sample_width=sample_w)) + + if path.endswith(".mp4"): + fps = i.get("fps", 24) + # max 5.0 and min 1.0 are just thresholds to filter some videos which have suitable duration. + assert ( + "num_frames" in i + ), "Expect that each element in the provided datset should have a item named `num_frames`" + if i["num_frames"] > self.too_long_factor * ( + self.num_frames * fps / self.train_fps * self.speed_factor + ): # too long video is not suitable for this training stage (self.num_frames) + cnt_too_long += 1 + continue + + # resample in case high fps, such as 50/60/90/144 -> train_fps(e.g, 24) + frame_interval = 1.0 if abs(fps - self.train_fps) < 0.1 else fps / self.train_fps + start_frame_idx = i.get("cut", [0])[0] + i["start_frame_idx"] = start_frame_idx + frame_indices = np.arange( + start_frame_idx, start_frame_idx + i["num_frames"], frame_interval + ).astype(int) + frame_indices = frame_indices[frame_indices < start_frame_idx + i["num_frames"]] + + # comment out it to enable dynamic frames training + if len(frame_indices) < self.num_frames and self.generator.random() < self.drop_short_ratio: + cnt_too_short += 1 + continue + + # too long video will be temporal-crop randomly + if len(frame_indices) > self.num_frames: + begin_index, end_index = self.temporal_sample(len(frame_indices)) + frame_indices = frame_indices[begin_index:end_index] + # frame_indices = frame_indices[:self.num_frames] # head crop + # to find a suitable end_frame_idx, to ensure we do not need pad video + end_frame_idx = find_closest_y( + len(frame_indices), vae_stride_t=self.ae_stride_t, model_ds_t=self.sp_size + ) + if end_frame_idx == -1: # too short that can not be encoded exactly by videovae + cnt_too_short += 1 + continue + frame_indices = frame_indices[:end_frame_idx] + + i["sample_frame_index"] = frame_indices.tolist() + + new_cap_list.append(i) + cnt_vid_after_filter += 1 + + elif path.endswith(".jpg"): # image + cnt_img_after_filter += 1 + i["sample_frame_index"] = [0] + new_cap_list.append(i) + + else: + raise NameError( + f"Unknown file extention {path.split('.')[-1]}, only support .mp4 for video and .jpg for image" + ) + + pre_define_shape = f"{len(i['sample_frame_index'])}x{sample_h}x{sample_w}" + sample_size.append(pre_define_shape) + # if shape_idx_dict.get(pre_define_shape, None) is None: + # shape_idx_dict[pre_define_shape] = [index] + # else: + # shape_idx_dict[pre_define_shape].append(index) + counter = Counter(sample_size) + counter_cp = counter + if not self.force_resolution and self.max_hxw is not None and self.min_hxw is not None: + assert all( + [np.prod(np.array(k.split("x")[1:]).astype(np.int32)) <= self.max_hxw for k in counter_cp.keys()] + ) + assert all( + [np.prod(np.array(k.split("x")[1:]).astype(np.int32)) >= self.min_hxw for k in counter_cp.keys()] + ) + + len_before_filter_major = len(sample_size) + filter_major_num = 4 * self.total_batch_size + new_cap_list, sample_size = zip( + *[[i, j] for i, j in zip(new_cap_list, sample_size) if counter[j] >= filter_major_num] + ) + for idx, shape in enumerate(sample_size): + if shape_idx_dict.get(shape, None) is None: + shape_idx_dict[shape] = [idx] + else: + shape_idx_dict[shape].append(idx) + cnt_filter_minority = len_before_filter_major - len(sample_size) + counter = Counter(sample_size) + + print( + f"no_cap: {cnt_no_cap}, no_resolution: {cnt_no_resolution}\n" + f"too_long: {cnt_too_long}, too_short: {cnt_too_short}\n" + f"cnt_img_res_mismatch_stride: {cnt_img_res_mismatch_stride}, cnt_vid_res_mismatch_stride: {cnt_vid_res_mismatch_stride}\n" + f"cnt_img_res_too_small: {cnt_img_res_too_small}, cnt_vid_res_too_small: {cnt_vid_res_too_small}\n" + f"cnt_img_aspect_mismatch: {cnt_img_aspect_mismatch}, cnt_vid_aspect_mismatch: {cnt_vid_aspect_mismatch}\n" + f"cnt_filter_minority: {cnt_filter_minority}\n" + f"cnt_no_existent: {cnt_no_existent}\n" + if self.filter_nonexistent + else "" + f"Counter(sample_size): {counter}\n" + f"cnt_vid: {cnt_vid}, cnt_vid_after_filter: {cnt_vid_after_filter}, use_ratio: {round(cnt_vid_after_filter/(cnt_vid+1e-6), 5)*100}%\n" + f"cnt_img: {cnt_img}, cnt_img_after_filter: {cnt_img_after_filter}, use_ratio: {round(cnt_img_after_filter/(cnt_img+1e-6), 5)*100}%\n" + f"before filter: {cnt}, after filter: {len(new_cap_list)}, use_ratio: {round(len(new_cap_list)/cnt, 5)*100}%" + ) + # import ipdb;ipdb.set_trace() + + if len(aesthetic_score) > 0: + stats_aesthetic = calculate_statistics(aesthetic_score) + print( + f"before filter: {cnt}, after filter: {len(new_cap_list)}\n" + f"aesthetic_score: {len(aesthetic_score)}, cnt_no_aesthetic: {cnt_no_aesthetic}\n" + f"{len([i for i in aesthetic_score if i>=5.75])} > 5.75, 4.5 > {len([i for i in aesthetic_score if i<=4.5])}\n" + f"Mean: {stats_aesthetic['mean']}, Var: {stats_aesthetic['variance']}, Std: {stats_aesthetic['std_dev']}\n" + f"Min: {stats_aesthetic['min']}, Max: {stats_aesthetic['max']}" + ) + + return new_cap_list, sample_size, shape_idx_dict def set_checkpoint(self, n_used_elements): for i in range(len(dataset_prog.n_used_elements)): @@ -185,16 +482,16 @@ def __len__(self): def __getitem__(self, idx): try: - data = self.get_data(idx) + future = self.executor.submit(self.get_data, idx) + data = future.result(timeout=self.timeout) + # data = self.get_data(idx) return data except Exception as e: - logger.info(f"Error with {e}") - # 打印异常堆栈 - if idx in dataset_prog.cap_list: - logger.info(f"Caught an exception! {dataset_prog.cap_list[idx]}") - # traceback.print_exc() - # traceback.print_stack() - return self.__getitem__(random.randint(0, self.__len__() - 1)) + if len(str(e)) < 2: + e = f"TimeoutError, {self.timeout}s timeout occur with {dataset_prog.cap_list[idx]['path']}" + print(f"Error with {e}") + index_cand = self.shape_idx_dict[self.sample_size[idx]] # pick same shape + return self.__getitem__(random.choice(index_cand)) def get_data(self, idx): path = dataset_prog.cap_list[idx]["path"] @@ -204,202 +501,233 @@ def get_data(self, idx): return self.get_image(idx) def get_video(self, idx): - video_path = dataset_prog.cap_list[idx]["path"] + video_data = dataset_prog.cap_list[idx] + video_path = video_data["path"] assert os.path.exists(video_path), f"file {video_path} do not exist!" - frame_indice = dataset_prog.cap_list[idx]["sample_frame_index"] - video = self.decord_read(video_path, predefine_num_frames=len(frame_indice)) # (T H W C) - - h, w = video.shape[1:3] - # NOTE: not suitable for 1:1 training in v1.3 - # assert h / w <= 17 / 16 and h / w >= 8 / 16, ( - # f"Only videos with a ratio (h/w) less than 17/16 and more than 8/16 are supported. But video ({video_path}) " - # + f"found ratio is {round(h / w, 2)} with the shape of {video.shape}" - # ) + sample_h = video_data["resolution"]["sample_height"] + sample_w = video_data["resolution"]["sample_width"] + if self.video_reader == "decord": + video = self.decord_read(video_data) + elif self.video_reader == "opencv": + video = self.opencv_read(video_data) + else: + NotImplementedError(f"Found {self.video_reader}, but support decord or opencv") + + h, w = video.shape[1:3] # (T, H, W, C) input_videos = {"image": video[0]} input_videos.update(dict([(f"image{i}", video[i + 1]) for i in range(len(video) - 1)])) output_videos = self.transform(**input_videos) video = np.stack([v for _, v in output_videos.items()], axis=0).transpose(3, 0, 1, 2) # T H W C -> C T H W - + assert ( + video.shape[2] == sample_h and video.shape[3] == sample_w + ), f"sample_h ({sample_h}), sample_w ({sample_w}), video ({video.shape})" # get token ids and attention mask if not self.return_text_emb if not self.return_text_emb: - text = dataset_prog.cap_list[idx]["cap"] + text = video_data["cap"] if not isinstance(text, list): text = [text] text = [random.choice(text)] + if video_data.get("aesthetic", None) is not None or video_data.get("aes", None) is not None: + aes = video_data.get("aesthetic", None) or video_data.get("aes", None) + text = [add_aesthetic_notice_video(text[0], aes)] + text = text_preprocessing(text, support_Chinese=self.support_Chinese) + + text = text if random.random() > self.cfg else "" - text = text_preprocessing(text, support_Chinese=self.support_Chinese) if random.random() > self.cfg else "" - text_tokens_and_mask = self.tokenizer( + text_tokens_and_mask_1 = self.tokenizer_1( text, max_length=self.model_max_length, padding="max_length", truncation=True, return_attention_mask=True, add_special_tokens=True, - return_tensors="np", + return_tensors="pt", ) - input_ids = text_tokens_and_mask["input_ids"] - cond_mask = text_tokens_and_mask["attention_mask"] - return dict(pixel_values=video, input_ids=input_ids, cond_mask=cond_mask) + input_ids_1 = text_tokens_and_mask_1["input_ids"] + cond_mask_1 = text_tokens_and_mask_1["attention_mask"] + + input_ids_2, cond_mask_2 = None, None + if self.tokenizer_2 is not None: + text_tokens_and_mask_2 = self.tokenizer_2( + text, + max_length=self.tokenizer_2.model_max_length, + padding="max_length", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids_2 = text_tokens_and_mask_2["input_ids"] + cond_mask_2 = text_tokens_and_mask_2["attention_mask"] + return dict( + pixel_values=video, + input_ids_1=input_ids_1, + cond_mask_1=cond_mask_1, + input_ids_2=input_ids_2, + cond_mask_2=cond_mask_2, + ) + else: - text_embed_paths = dataset_prog.cap_list[idx]["text_embed_path"] - text_embed_path = random.choice(text_embed_paths) - text_emb, cond_mask = self.parse_text_emb(text_embed_path) - return dict(pixel_values=video, input_ids=text_emb, cond_mask=cond_mask) + if "text_embed_path_1" in video_data: + text_embed_paths = video_data["text_embed_path_1"] + text_embed_path = random.choice(text_embed_paths) + text_emb_1, cond_mask_1 = self.parse_text_emb(text_embed_path) + text_emb_2, cond_mask_2 = None, None + if "text_embed_path_2" in video_data: + text_embed_paths = video_data["text_embed_path_2"] + text_embed_path = random.choice(text_embed_paths) + text_emb_2, cond_mask_2 = self.parse_text_emb(text_embed_path) + return dict( + pixel_values=video, + input_ids_1=text_emb_1, + cond_mask_1=cond_mask_1, + input_ids_2=text_emb_2, + cond_mask_2=cond_mask_2, + ) def get_image(self, idx): image_data = dataset_prog.cap_list[idx] # [{'path': path, 'cap': cap}, ...] - + sample_h = image_data["resolution"]["sample_height"] + sample_w = image_data["resolution"]["sample_width"] # import ipdb;ipdb.set_trace() image = Image.open(image_data["path"]).convert("RGB") # [h, w, c] image = np.array(image) # [h, w, c] - image = ( - self.transform_topcrop(image=image)["image"] - if "human_images" in image_data["path"] - else self.transform(image=image)["image"] - ) + image = self.transform(image=image)["image"] # [h, w, c] -> [c h w] -> [C 1 H W] image = image.transpose(2, 0, 1)[:, None, ...] + assert ( + image.shape[2] == sample_h and image.shape[3] == sample_w + ), f"image_data: {image_data}, but found image {image.shape}" # get token ids and attention mask if not self.return_text_emb if not self.return_text_emb: caps = image_data["cap"] if isinstance(image_data["cap"], list) else [image_data["cap"]] caps = [random.choice(caps)] + if image_data.get("aesthetic", None) is not None or image_data.get("aes", None) is not None: + aes = image_data.get("aesthetic", None) or image_data.get("aes", None) + caps = [add_aesthetic_notice_image(caps[0], aes)] text = text_preprocessing(caps, support_Chinese=self.support_Chinese) - input_ids, cond_mask = [], [] text = text if random.random() > self.cfg else "" - text_tokens_and_mask = self.tokenizer( + + text_tokens_and_mask_1 = self.tokenizer_1( text, max_length=self.model_max_length, padding="max_length", truncation=True, return_attention_mask=True, add_special_tokens=True, - return_tensors="np", + return_tensors="pt", + ) + input_ids_1 = text_tokens_and_mask_1["input_ids"] # 1, l + cond_mask_1 = text_tokens_and_mask_1["attention_mask"] # 1, l + + input_ids_2, cond_mask_2 = None, None + if self.tokenizer_2 is not None: + text_tokens_and_mask_2 = self.tokenizer_2( + text, + max_length=self.tokenizer_2.model_max_length, + padding="max_length", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids_2 = text_tokens_and_mask_2["input_ids"] # 1, l + cond_mask_2 = text_tokens_and_mask_2["attention_mask"] # 1, l + + return dict( + pixel_values=image, + input_ids_1=input_ids_1, + cond_mask_1=cond_mask_1, + input_ids_2=input_ids_2, + cond_mask_2=cond_mask_2, ) - input_ids = text_tokens_and_mask["input_ids"] # 1, l - cond_mask = text_tokens_and_mask["attention_mask"] # 1, l - return dict(pixel_values=image, input_ids=input_ids, cond_mask=cond_mask) else: - text_embed_paths = dataset_prog.cap_list[idx]["text_embed_path"] - text_embed_path = random.choice(text_embed_paths) - text_emb, cond_mask = self.parse_text_emb(text_embed_path) - return dict(pixel_values=image, input_ids=text_emb, cond_mask=cond_mask) + if "text_embed_path_1" in image_data: + text_embed_paths = image_data["text_embed_path_1"] + text_embed_path = random.choice(text_embed_paths) + text_emb_1, cond_mask_1 = self.parse_text_emb(text_embed_path) + text_emb_2, cond_mask_2 = None, None + if "text_embed_path_2" in image_data: + text_embed_paths = image_data["text_embed_path_2"] + text_embed_path = random.choice(text_embed_paths) + text_emb_2, cond_mask_2 = self.parse_text_emb(text_embed_path) + return dict( + pixel_values=image, + input_ids_1=text_emb_1, + cond_mask_1=cond_mask_1, + input_ids_2=text_emb_2, + cond_mask_2=cond_mask_2, + ) - def define_frame_index(self, cap_list): - new_cap_list = [] - sample_num_frames = [] - cnt_too_long = 0 - cnt_too_short = 0 - cnt_no_cap = 0 - cnt_no_resolution = 0 - cnt_resolution_mismatch = 0 - cnt_movie = 0 - cnt_img = 0 - for i in cap_list: - path = i["path"] - cap = i.get("cap", None) - # ======no caption===== - if cap is None: - cnt_no_cap += 1 - continue - if path.endswith(".mp4"): - # ======no fps and duration===== - duration = i.get("duration", None) - fps = i.get("fps", None) - if fps is None or duration is None: - continue + def decord_read(self, video_data): + path = video_data["path"] + predefine_frame_indice = video_data["sample_frame_index"] + start_frame_idx = video_data["start_frame_idx"] + clip_total_frames = video_data["num_frames"] + fps = video_data["fps"] + s_x, e_x, s_y, e_y = video_data.get("crop", [None, None, None, None]) - # ======resolution mismatch===== - resolution = i.get("resolution", None) - if resolution is None: - cnt_no_resolution += 1 - continue - else: - if resolution.get("height", None) is None or resolution.get("width", None) is None: - cnt_no_resolution += 1 - continue - height, width = i["resolution"]["height"], i["resolution"]["width"] - aspect = self.max_height / self.max_width - hw_aspect_thr = 2.0 #NOTE: for 1:1 frame training - is_pick = filter_resolution( - height, - width, - max_h_div_w_ratio=hw_aspect_thr * aspect, - min_h_div_w_ratio=1 / hw_aspect_thr * aspect, - ) - if not is_pick: - cnt_resolution_mismatch += 1 - continue + predefine_num_frames = len(predefine_frame_indice) + # decord_vr = decord.VideoReader(path, ctx=decord.cpu(0), num_threads=1) + decord_vr = DecordDecoder(path) - # import ipdb;ipdb.set_trace() - i["num_frames"] = int(fps * duration) - # max 5.0 and min 1.0 are just thresholds to filter some videos which have suitable duration. - if i["num_frames"] > self.duration_threshold * ( - self.num_frames * fps / self.train_fps * self.speed_factor - ): # too long video is not suitable for this training stage (self.num_frames) - cnt_too_long += 1 - continue + frame_indices = self.get_actual_frame( + fps, start_frame_idx, clip_total_frames, path, predefine_num_frames, predefine_frame_indice + ) - # resample in case high fps, such as 50/60/90/144 -> train_fps(e.g, 24) - frame_interval = fps / self.train_fps - start_frame_idx = 0 - frame_indices = np.arange(start_frame_idx, i["num_frames"], frame_interval).astype(int) - frame_indices = frame_indices[frame_indices < i["num_frames"]] + # video_data = decord_vr.get_batch(frame_indices).asnumpy() + # video_data = torch.from_numpy(video_data) + video_data = decord_vr.get_batch(frame_indices) + if video_data is not None: + if s_y is not None: + video_data = video_data[ + :, + s_y:e_y, + s_x:e_x, + :, + ] + else: + raise ValueError(f"Get video_data {video_data}") - # comment out it to enable dynamic frames training - if len(frame_indices) < self.num_frames and random.random() < self.drop_short_ratio: - cnt_too_short += 1 - continue + return video_data - # too long video will be temporal-crop randomly - if len(frame_indices) > self.num_frames: - begin_index, end_index = self.temporal_sample(len(frame_indices)) - frame_indices = frame_indices[begin_index:end_index] - # frame_indices = frame_indices[:self.num_frames] # head crop - # to find a suitable end_frame_idx, to ensure we do not need pad video - end_frame_idx = find_closest_y(len(frame_indices), vae_stride_t=4, model_ds_t=4) - if end_frame_idx == -1: # too short that can not be encoded exactly by videovae - if self.num_frames < 29: - logger.warning( - "The numbder of frames is less than 29, which is too short to be encoded by causal vae." - ) - cnt_too_short += 1 - continue - frame_indices = frame_indices[:end_frame_idx] - - i["sample_frame_index"] = frame_indices.tolist() - new_cap_list.append(i) - i["sample_num_frames"] = len(i["sample_frame_index"]) # will use in dataloader(group sampler) - sample_num_frames.append(i["sample_num_frames"]) - elif path.endswith(".jpg"): # image - cnt_img += 1 - new_cap_list.append(i) - i["sample_num_frames"] = 1 - sample_num_frames.append(i["sample_num_frames"]) - else: - raise NameError( - f"Unknown file extention {path.split('.')[-1]}, only support .mp4 for video and .jpg for image" - ) - # import ipdb;ipdb.set_trace() - logger.info( - f"no_cap: {cnt_no_cap}, too_long: {cnt_too_long}, too_short: {cnt_too_short}, " - f"no_resolution: {cnt_no_resolution}, resolution_mismatch: {cnt_resolution_mismatch}, " - f"Counter(sample_num_frames): {Counter(sample_num_frames)}, cnt_movie: {cnt_movie}, cnt_img: {cnt_img}, " - f"before filter: {len(cap_list)}, after filter: {len(new_cap_list)}" + def opencv_read(self, video_data): + path = video_data["path"] + predefine_frame_indice = video_data["sample_frame_index"] + start_frame_idx = video_data["start_frame_idx"] + clip_total_frames = video_data["num_frames"] + fps = video_data["fps"] + s_x, e_x, s_y, e_y = video_data.get("crop", [None, None, None, None]) + + predefine_num_frames = len(predefine_frame_indice) + cv2_vr = cv2.VideoCapture(path) + if not cv2_vr.isOpened(): + raise ValueError(f"can not open {path}") + frame_indices = self.get_actual_frame( + fps, start_frame_idx, clip_total_frames, path, predefine_num_frames, predefine_frame_indice ) - return new_cap_list, sample_num_frames - def decord_read(self, path, predefine_num_frames): - decord_vr = self.v_decoder(path) - total_frames = len(decord_vr) - fps = decord_vr.get_avg_fps() if decord_vr.get_avg_fps() > 0 else 30.0 - # import ipdb;ipdb.set_trace() + video_data = [] + for frame_idx in frame_indices: + cv2_vr.set(1, frame_idx) + _, frame = cv2_vr.read() + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + video_data.append(frame) # H, W, C + cv2_vr.release() + video_data = np.stack(video_data) # (T, H, W, C) + if s_y is not None: + video_data = video_data[:, s_y:e_y, s_x:e_x, :] + return video_data + + def get_actual_frame( + self, fps, start_frame_idx, clip_total_frames, path, predefine_num_frames, predefine_frame_indice + ): # resample in case high fps, such as 50/60/90/144 -> train_fps(e.g, 24) frame_interval = 1.0 if abs(fps - self.train_fps) < 0.1 else fps / self.train_fps - start_frame_idx = 0 - frame_indices = np.arange(start_frame_idx, total_frames, frame_interval).astype(int) - frame_indices = frame_indices[frame_indices < total_frames] - # import ipdb;ipdb.set_trace() + frame_indices = np.arange(start_frame_idx, start_frame_idx + clip_total_frames, frame_interval).astype(int) + frame_indices = frame_indices[frame_indices < start_frame_idx + clip_total_frames] + # speed up max_speed_factor = len(frame_indices) / self.num_frames if self.speed_factor > 1 and max_speed_factor > 1: @@ -416,22 +744,22 @@ def decord_read(self, path, predefine_num_frames): # frame_indices = frame_indices[:self.num_frames] # head crop # to find a suitable end_frame_idx, to ensure we do not need pad video - end_frame_idx = find_closest_y(len(frame_indices), vae_stride_t=4, model_ds_t=4) + end_frame_idx = find_closest_y(len(frame_indices), vae_stride_t=self.ae_stride_t, model_ds_t=self.sp_size) if end_frame_idx == -1: # too short that can not be encoded exactly by videovae raise IndexError( - f"video ({path}) has {total_frames} frames, but need to sample {len(frame_indices)} frames ({frame_indices})" + f"video ({path}) has {clip_total_frames} frames, but need to sample {len(frame_indices)} frames ({frame_indices})" ) frame_indices = frame_indices[:end_frame_idx] if predefine_num_frames != len(frame_indices): raise ValueError( - f"predefine_num_frames ({predefine_num_frames}) is not equal with frame_indices ({len(frame_indices)})" + f"video ({path}) predefine_num_frames ({predefine_num_frames}) ({predefine_frame_indice}) is \ + not equal with frame_indices ({len(frame_indices)}) ({frame_indices})" ) if len(frame_indices) < self.num_frames and self.drop_short_ratio >= 1: raise IndexError( - f"video ({path}) has {total_frames} frames, but need to sample {len(frame_indices)} frames ({frame_indices})" + f"video ({path}) has {clip_total_frames} frames, but need to sample {len(frame_indices)} frames ({frame_indices})" ) - video_data = decord_vr.get_batch(frame_indices).asnumpy() # (T, H, W, C) - return video_data + return frame_indices def get_text_embed_file_path(self, item): file_path = item["path"] @@ -462,33 +790,3 @@ def parse_text_emb(self, npz): mask = mask[None, ...] return text_emb, mask # (1, L, D), (1, L) - - def read_jsons(self, data): - cap_lists = [] - with open(data, "r") as f: - folder_anno = [i.strip().split(",") for i in f.readlines() if len(i.strip()) > 0] - for item in folder_anno: - if len(item) == 2: - folder, anno = item - elif len(item) == 3: - folder, text_emb_folder, anno = item - else: - raise ValueError(f"Expect to have two or three paths, but got {len(item)} input paths") - if self.return_text_emb: - assert ( - len(item) == 3 - ), "When returning text embeddings, please give three paths: video folder, text_embed folder, annotation file" - with open(anno, "r") as f: - sub_list = json.load(f) - logger.info(f"Building {anno}...") - for i in range(len(sub_list)): - if self.return_text_emb: - text_embeds_paths = self.get_text_embed_file_path(sub_list[i]) - sub_list[i]["text_embed_path"] = [opj(text_emb_folder, tp) for tp in text_embeds_paths] - sub_list[i]["path"] = opj(folder, sub_list[i]["path"]) - cap_lists += sub_list - return cap_lists - - def get_cap_list(self): - cap_lists = self.read_jsons(self.data) - return cap_lists diff --git a/examples/opensora_pku/opensora/dataset/transform.py b/examples/opensora_pku/opensora/dataset/transform.py index b1edd43fa0..12734c8d85 100644 --- a/examples/opensora_pku/opensora/dataset/transform.py +++ b/examples/opensora_pku/opensora/dataset/transform.py @@ -5,8 +5,8 @@ import albumentations import ftfy -from bs4 import BeautifulSoup import numpy as np +from bs4 import BeautifulSoup __all__ = ["create_video_transforms", "t5_text_preprocessing"] @@ -91,20 +91,22 @@ def center_crop_th_tw(image, th, tw, top_crop, **kwargs): cropped_image = crop(image, i, j, new_h, new_w) return cropped_image -def resize(image, h, w, interpolation_mode): - resize_func = albumentations.Resize(h, w, interpolation = interpolation_mode) +def resize(image, h, w, interpolation_mode): + resize_func = albumentations.Resize(h, w, interpolation=interpolation_mode) return resize_func(image=image)["image"] + def get_params(h, w, stride): th, tw = h // stride * stride, w // stride * stride - + i = (h - th) // 2 j = (w - tw) // 2 return i, j, th, tw + def spatial_stride_crop_video(image, stride, **kwargs): """ Args: @@ -113,31 +115,33 @@ def spatial_stride_crop_video(image, stride, **kwargs): numpy array: cropped video clip by stride. size is (OH, OW, C) """ - h, w = image.shape[:2] + h, w = image.shape[:2] i, j, h, w = get_params(h, w, stride) return crop(image, i, j, h, w) + def maxhxw_resize(image, max_hxw, interpolation_mode, **kwargs): - """ - First use the h*w, - then resize to the specified size - Args: - image (numpy array): Video clip to be cropped. Size is (H, W, C) - Returns: - numpy array: scale resized video clip. - """ - h, w = image.shape[:2] - if h * w > max_hxw: - scale_factor = np.sqrt(max_hxw / (h * w)) - tr_h = int(h * scale_factor) - tr_w = int(w * scale_factor) - else: - tr_h = h - tr_w = w - if h == tr_h and w == tr_w: - return image - resize_image = resize(image, tr_h, tr_w, interpolation_mode) - return resize_image + """ + First use the h*w, + then resize to the specified size + Args: + image (numpy array): Video clip to be cropped. Size is (H, W, C) + Returns: + numpy array: scale resized video clip. + """ + h, w = image.shape[:2] + if h * w > max_hxw: + scale_factor = np.sqrt(max_hxw / (h * w)) + tr_h = int(h * scale_factor) + tr_w = int(w * scale_factor) + else: + tr_h = h + tr_w = w + if h == tr_h and w == tr_w: + return image + resize_image = resize(image, tr_h, tr_w, interpolation_mode) + return resize_image + # create text transform(preprocess) bad_punct_regex = re.compile( @@ -311,3 +315,179 @@ def __call__(self, t, h, w): if self.extra_1: truncate_t = truncate_t + 1 return 0, truncate_t + + +keywords = [ + " man ", + " woman ", + " person ", + " people ", + "human", + " individual ", + " child ", + " kid ", + " girl ", + " boy ", +] +keywords += [i[:-1] + "s " for i in keywords] + +masking_notices = [ + "Note: The faces in this image are blurred.", + "This image contains faces that have been pixelated.", + "Notice: Faces in this image are masked.", + "Please be aware that the faces in this image are obscured.", + "The faces in this image are hidden.", + "This is an image with blurred faces.", + "The faces in this image have been processed.", + "Attention: Faces in this image are not visible.", + "The faces in this image are partially blurred.", + "This image has masked faces.", + "Notice: The faces in this picture have been altered.", + "This is a picture with obscured faces.", + "The faces in this image are pixelated.", + "Please note, the faces in this image have been blurred.", + "The faces in this photo are hidden.", + "The faces in this picture have been masked.", + "Note: The faces in this picture are altered.", + "This is an image where faces are not clear.", + "Faces in this image have been obscured.", + "This picture contains masked faces.", + "The faces in this image are processed.", + "The faces in this picture are not visible.", + "Please be aware, the faces in this photo are pixelated.", + "The faces in this picture have been blurred.", +] + +webvid_watermark_notices = [ + "This video has a faint Shutterstock watermark in the center.", + "There is a slight Shutterstock watermark in the middle of this video.", + "The video contains a subtle Shutterstock watermark in the center.", + "This video features a light Shutterstock watermark at its center.", + "A faint Shutterstock watermark is present in the middle of this video.", + "There is a mild Shutterstock watermark at the center of this video.", + "This video has a slight Shutterstock watermark in the middle.", + "You can see a faint Shutterstock watermark in the center of this video.", + "A subtle Shutterstock watermark appears in the middle of this video.", + "This video includes a light Shutterstock watermark at its center.", +] + + +high_aesthetic_score_notices_video = [ + "This video has a high aesthetic quality.", + "The beauty of this video is exceptional.", + "This video scores high in aesthetic value.", + "With its harmonious colors and balanced composition.", + "This video ranks highly for aesthetic quality", + "The artistic quality of this video is excellent.", + "This video is rated high for beauty.", + "The aesthetic quality of this video is impressive.", + "This video has a top aesthetic score.", + "The visual appeal of this video is outstanding.", +] + +low_aesthetic_score_notices_video = [ + "This video has a low aesthetic quality.", + "The beauty of this video is minimal.", + "This video scores low in aesthetic appeal.", + "The aesthetic quality of this video is below average.", + "This video ranks low for beauty.", + "The artistic quality of this video is lacking.", + "This video has a low score for aesthetic value.", + "The visual appeal of this video is low.", + "This video is rated low for beauty.", + "The aesthetic quality of this video is poor.", +] + + +high_aesthetic_score_notices_image = [ + "This image has a high aesthetic quality.", + "The beauty of this image is exceptional", + "This photo scores high in aesthetic value.", + "With its harmonious colors and balanced composition.", + "This image ranks highly for aesthetic quality.", + "The artistic quality of this photo is excellent.", + "This image is rated high for beauty.", + "The aesthetic quality of this image is impressive.", + "This photo has a top aesthetic score.", + "The visual appeal of this image is outstanding.", +] + +low_aesthetic_score_notices_image = [ + "This image has a low aesthetic quality.", + "The beauty of this image is minimal.", + "This image scores low in aesthetic appeal.", + "The aesthetic quality of this image is below average.", + "This image ranks low for beauty.", + "The artistic quality of this image is lacking.", + "This image has a low score for aesthetic value.", + "The visual appeal of this image is low.", + "This image is rated low for beauty.", + "The aesthetic quality of this image is poor.", +] + +high_aesthetic_score_notices_image_human = [ + "High-quality image with visible human features and high aesthetic score.", + "Clear depiction of an individual in a high-quality image with top aesthetics.", + "High-resolution photo showcasing visible human details and high beauty rating.", + "Detailed, high-quality image with well-defined human subject and strong aesthetic appeal.", + "Sharp, high-quality portrait with clear human features and high aesthetic value.", + "High-quality image featuring a well-defined human presence and exceptional aesthetics.", + "Visible human details in a high-resolution photo with a high aesthetic score.", + "Clear, high-quality image with prominent human subject and superior aesthetic rating.", + "High-quality photo capturing a visible human with excellent aesthetics.", + "Detailed, high-quality image of a human with high visual appeal and aesthetic value.", +] + + +def calculate_statistics(data): + if len(data) == 0: + return None + data = np.array(data) + mean = np.mean(data) + variance = np.var(data) + std_dev = np.std(data) + minimum = np.min(data) + maximum = np.max(data) + + return {"mean": mean, "variance": variance, "std_dev": std_dev, "min": minimum, "max": maximum} + + +def maxhwresize(ori_height, ori_width, max_hxw): + if ori_height * ori_width > max_hxw: + scale_factor = np.sqrt(max_hxw / (ori_height * ori_width)) + new_height = int(ori_height * scale_factor) + new_width = int(ori_width * scale_factor) + else: + new_height = ori_height + new_width = ori_width + return new_height, new_width + + +def add_aesthetic_notice_video(caption, aesthetic_score): + if aesthetic_score <= 4.25: + notice = random.choice(low_aesthetic_score_notices_video) + return random.choice([caption + " " + notice, notice + " " + caption]) + if aesthetic_score >= 5.75: + notice = random.choice(high_aesthetic_score_notices_video) + return random.choice([caption + " " + notice, notice + " " + caption]) + return caption + + +def add_aesthetic_notice_image(caption, aesthetic_score): + if aesthetic_score <= 4.25: + notice = random.choice(low_aesthetic_score_notices_image) + return random.choice([caption + " " + notice, notice + " " + caption]) + if aesthetic_score >= 5.75: + notice = random.choice(high_aesthetic_score_notices_image) + return random.choice([caption + " " + notice, notice + " " + caption]) + return caption + + +def add_high_aesthetic_notice_image(caption): + notice = random.choice(high_aesthetic_score_notices_image) + return random.choice([caption + " " + notice, notice + " " + caption]) + + +def add_high_aesthetic_notice_image_human(caption): + notice = random.choice(high_aesthetic_score_notices_image_human) + return random.choice([caption + " " + notice, notice + " " + caption]) diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index 0b2c775745..5c8e71398b 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -26,7 +26,7 @@ from opensora.npu_config import npu_config from opensora.train.commons import create_loss_scaler, parse_args from opensora.utils.callbacks import EMAEvalSwapCallback, PerfRecorderCallback -from opensora.utils.dataset_utils import Collate, LengthGroupedBatchSampler +from opensora.utils.dataset_utils import Collate, LengthGroupedSampler from opensora.utils.ema import EMA from opensora.utils.message_utils import print_banner from opensora.utils.utils import get_precision, save_diffusers_json @@ -299,36 +299,22 @@ def main(args): initial_global_step_for_sampler = args.trained_data_global_step else: initial_global_step_for_sampler = 0 + total_batch_size = args.train_batch_size * device_num * args.gradient_accumulation_steps + total_batch_size = total_batch_size // args.sp_size * args.train_sp_batch_size + args.total_batch_size = total_batch_size if args.max_hxw is not None and args.min_hxw is None: args.min_hxw = args.max_hxw // 4 train_dataset = getdataset(args, dataset_file=args.data) - sampler = ( - LengthGroupedBatchSampler( - args.train_batch_size, - world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), - lengths=train_dataset.lengths, - group_frame=args.group_frame, # v1.2 - group_resolution=args.group_resolution, # v1.2 - initial_global_step_for_sampler=initial_global_step_for_sampler, # TODO: use in v1.3 - group_data=args.group_data, # TODO: use in v1.3 - ) - if (args.group_frame or args.group_resolution) # v1.2 - else None # v1.2 - ) - collate_fn = Collate( + sampler = LengthGroupedSampler( args.train_batch_size, - args.group_frame, - args.group_resolution, - args.max_height, - args.max_width, - args.ae_stride, - args.ae_stride_t, - args.patch_size, - args.patch_size_t, - args.num_frames, - args.use_image_num, + world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), + gradient_accumulation_size=args.gradient_accumulation_steps, + initial_global_step=initial_global_step_for_sampler, + lengths=train_dataset.lengths, + group_data=args.group_data, ) + collate_fn = Collate(args) dataloader = create_dataloader( train_dataset, batch_size=args.train_batch_size, @@ -353,17 +339,15 @@ def main(args): assert os.path.exists(args.val_data), f"validation dataset file must exist, but got {args.val_data}" print_banner("Validation dataset Loading...") val_dataset = getdataset(args, dataset_file=args.val_data) - sampler = ( - LengthGroupedBatchSampler( - args.val_batch_size, - world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), - lengths=val_dataset.lengths, - group_frame=args.group_frame, - group_resolution=args.group_resolution, - ) - if (args.group_frame or args.group_resolution) - else None + sampler = LengthGroupedSampler( + args.val_batch_size, + world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), + lengths=val_dataset.lengths, + gradient_accumulation_size=args.gradient_accumulation_steps, + initial_global_step=initial_global_step_for_sampler, + group_data=args.group_data, ) + collate_fn = Collate( args.val_batch_size, args.group_frame, @@ -674,11 +658,6 @@ def main(args): callback.append(ProfilerCallbackEpoch(2, 2, "./profile_data")) # Train! - assert ( - args.train_sp_batch_size == 1 - ), "Do not support train_sp_batch_size other than 1. Please set `--train_sp_batch_size 1`" - total_batch_size = args.train_batch_size * device_num * args.gradient_accumulation_steps - total_batch_size = total_batch_size // args.sp_size * args.train_sp_batch_size # 5. log and save config if rank_id == 0: @@ -764,7 +743,12 @@ def parse_t2v_train_args(parser): parser.add_argument("--hw_stride", type=int, default=32) parser.add_argument("--force_resolution", action="store_true") parser.add_argument("--trained_data_global_step", type=int, default=None) - parser.add_argument("--use_decord", action="store_true") + parser.add_argument( + "--use_decord", + type=str2bool, + default=True, + help="whether to use decord to load videos. If not, use opencv to load videos.", + ) # text encoder & vae & diffusion model parser.add_argument("--vae_fp32", action="store_true") diff --git a/examples/opensora_pku/opensora/utils/dataset_utils.py b/examples/opensora_pku/opensora/utils/dataset_utils.py index be4c472977..27caeb88eb 100644 --- a/examples/opensora_pku/opensora/utils/dataset_utils.py +++ b/examples/opensora_pku/opensora/utils/dataset_utils.py @@ -1,6 +1,6 @@ import math import random -from collections import Counter +from collections import Counter, defaultdict from typing import List, Optional import decord @@ -128,61 +128,77 @@ def pad_to_multiple(number, ds_stride): class Collate: - def __init__( - self, - batch_size, - group_frame, - group_resolution, - max_height, - max_width, - ae_stride, - ae_stride_t, - patch_size, - patch_size_t, - num_frames, - use_image_num, - ): - self.batch_size = batch_size - self.group_frame = group_frame - self.group_resolution = group_resolution + def __init__(self, args): + self.batch_size = args.train_batch_size + self.group_data = args.group_data + self.force_resolution = args.force_resolution - self.max_height = max_height - self.max_width = max_width - self.ae_stride = ae_stride + self.max_height = args.max_height + self.max_width = args.max_width + self.ae_stride = args.ae_stride - self.ae_stride_t = ae_stride_t + self.ae_stride_t = args.ae_stride_t self.ae_stride_thw = (self.ae_stride_t, self.ae_stride, self.ae_stride) - self.patch_size = patch_size - self.patch_size_t = patch_size_t + self.patch_size = args.patch_size + self.patch_size_t = args.patch_size_t - self.num_frames = num_frames - self.use_image_num = use_image_num + self.num_frames = args.num_frames self.max_thw = (self.num_frames, self.max_height, self.max_width) def package(self, batch): batch_tubes = [i["pixel_values"] for i in batch] # b [c t h w] - input_ids = [i["input_ids"] for i in batch] # b [1 l] - cond_mask = [i["cond_mask"] for i in batch] # b [1 l] - return batch_tubes, input_ids, cond_mask + input_ids_1 = [i["input_ids_1"] for i in batch] # b [1 l] + cond_mask_1 = [i["cond_mask_1"] for i in batch] # b [1 l] + input_ids_2 = [i["input_ids_2"] for i in batch] # b [1 l] + cond_mask_2 = [i["cond_mask_2"] for i in batch] # b [1 l] + assert all([i is None for i in input_ids_2]) or all([i is not None for i in input_ids_2]) + assert all([i is None for i in cond_mask_2]) or all([i is not None for i in cond_mask_2]) + if all([i is None for i in input_ids_2]): + input_ids_2 = None + if all([i is None for i in cond_mask_2]): + cond_mask_2 = None + return batch_tubes, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 def __call__(self, batch): - batch_tubes, input_ids, cond_mask = self.package(batch) + batch_tubes, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 = self.package(batch) ds_stride = self.ae_stride * self.patch_size t_ds_stride = self.ae_stride_t * self.patch_size_t - pad_batch_tubes, attention_mask, input_ids, cond_mask = self.process( - batch_tubes, input_ids, cond_mask, t_ds_stride, ds_stride, self.max_thw, self.ae_stride_thw + pad_batch_tubes, attention_mask, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 = self.process( + batch_tubes, + input_ids_1, + cond_mask_1, + input_ids_2, + cond_mask_2, + t_ds_stride, + ds_stride, + self.max_thw, + self.ae_stride_thw, ) - # assert not np.any(np.isnan(pad_batch_tubes)), 'after pad_batch_tubes' - return pad_batch_tubes, attention_mask, input_ids, cond_mask + assert not np.any(np.isnan(pad_batch_tubes)), "after pad_batch_tubes" + if input_ids_2 is not None and cond_mask_2 is not None: + return pad_batch_tubes, attention_mask, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 + else: + return pad_batch_tubes, attention_mask, input_ids_1, cond_mask_1 - def process(self, batch_tubes, input_ids, cond_mask, t_ds_stride, ds_stride, max_thw, ae_stride_thw): + def process( + self, + batch_tubes, + input_ids_1, + cond_mask_1, + input_ids_2, + cond_mask_2, + t_ds_stride, + ds_stride, + max_thw, + ae_stride_thw, + ): # pad to max multiple of ds_stride batch_input_size = [i.shape for i in batch_tubes] # [(c t h w), (c t h w)] assert len(batch_input_size) == self.batch_size - if self.group_frame or self.group_resolution or self.batch_size == 1: # + if self.group_data or self.batch_size == 1: # len_each_batch = batch_input_size idx_length_dict = dict([*zip(list(range(self.batch_size)), len_each_batch)]) count_dict = Counter(len_each_batch) @@ -195,13 +211,25 @@ def process(self, batch_tubes, input_ids, cond_mask, t_ds_stride, ds_stride, max random_select_batch = [ random.choice(candidate_batch) for _ in range(len(len_each_batch) - len(candidate_batch)) ] - # print(batch_input_size, idx_length_dict, count_dict, sorted_by_value, pick_length, candidate_batch, random_select_batch) + print( + batch_input_size, + idx_length_dict, + count_dict, + sorted_by_value, + pick_length, + candidate_batch, + random_select_batch, + ) pick_idx = candidate_batch + random_select_batch batch_tubes = [batch_tubes[i] for i in pick_idx] batch_input_size = [i.shape for i in batch_tubes] # [(c t h w), (c t h w)] - input_ids = [input_ids[i] for i in pick_idx] # b [1, l] - cond_mask = [cond_mask[i] for i in pick_idx] # b [1, l] + input_ids_1 = [input_ids_1[i] for i in pick_idx] # b [1, l] + cond_mask_1 = [cond_mask_1[i] for i in pick_idx] # b [1, l] + if input_ids_2 is not None: + input_ids_2 = [input_ids_2[i] for i in pick_idx] # b [1, l] + if cond_mask_2 is not None: + cond_mask_2 = [cond_mask_2[i] for i in pick_idx] # b [1, l] for i in range(1, self.batch_size): assert batch_input_size[0] == batch_input_size[i] @@ -217,7 +245,6 @@ def process(self, batch_tubes, input_ids, cond_mask, t_ds_stride, ds_stride, max ) pad_max_t = pad_max_t + 1 - self.ae_stride_t each_pad_t_h_w = [[pad_max_t - i.shape[1], pad_max_h - i.shape[2], pad_max_w - i.shape[3]] for i in batch_tubes] - pad_batch_tubes = [ np.pad(im, [[0, 0]] * (len(im.shape) - 3) + [[0, pad_t], [0, pad_h], [0, pad_w]], constant_values=0) for (pad_t, pad_h, pad_w), im in zip(each_pad_t_h_w, batch_tubes) @@ -248,81 +275,61 @@ def process(self, batch_tubes, input_ids, cond_mask, t_ds_stride, ds_stride, max for i in valid_latent_size ] attention_mask = np.stack(attention_mask, axis=0) # b t h w - if self.batch_size == 1 or self.group_frame or self.group_resolution: + if self.batch_size == 1 or self.group_data: + if not np.all(attention_mask.astype(np.bool_)): + print( + batch_input_size, + (max_t, max_h, max_w), + (pad_max_t, pad_max_h, pad_max_w), + each_pad_t_h_w, + max_latent_size, + valid_latent_size, + ) assert np.all(attention_mask.astype(np.bool_)) - input_ids = np.stack(input_ids, axis=0) # b 1 l - cond_mask = np.stack(cond_mask, axis=0) # b 1 l - if input_ids.dtype == np.int64: - input_ids = input_ids.astype(np.int32) - if attention_mask.dtype == np.int64: - attention_mask = attention_mask.astype(np.int32) - if cond_mask.dtype == np.int64: - cond_mask = cond_mask.astype(np.int32) - return pad_batch_tubes, attention_mask, input_ids, cond_mask + input_ids_1 = np.stack(input_ids_1) # b 1 l + cond_mask_1 = np.stack(cond_mask_1) # b 1 l + input_ids_2 = np.stack(input_ids_2) if input_ids_2 is not None else input_ids_2 # b 1 l + cond_mask_2 = np.stack(cond_mask_2) if cond_mask_2 is not None else cond_mask_2 # b 1 l + return pad_batch_tubes, attention_mask, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 -def split_to_even_chunks(indices, lengths, num_chunks, batch_size): - """ - Split a list of indices into `chunks` chunks of roughly equal lengths. - """ - if len(indices) % num_chunks != 0: - chunks = [indices[i::num_chunks] for i in range(num_chunks)] - else: - num_indices_per_chunk = len(indices) // num_chunks - - chunks = [[] for _ in range(num_chunks)] - chunks_lengths = [0 for _ in range(num_chunks)] - for index in indices: - shortest_chunk = chunks_lengths.index(min(chunks_lengths)) - chunks[shortest_chunk].append(index) - chunks_lengths[shortest_chunk] += lengths[index] - if len(chunks[shortest_chunk]) == num_indices_per_chunk: - chunks_lengths[shortest_chunk] = float("inf") - # return chunks - - pad_chunks = [] - for idx, chunk in enumerate(chunks): - if batch_size != len(chunk): - assert batch_size > len(chunk) - if len(chunk) != 0: - chunk = chunk + [random.choice(chunk) for _ in range(batch_size - len(chunk))] - else: - chunk = random.choice(pad_chunks) - print(chunks[idx], "->", chunk) - pad_chunks.append(chunk) - return pad_chunks - - -def group_frame_fun(indices, lengths): - # sort by num_frames - indices.sort(key=lambda i: lengths[i], reverse=True) - return indices +def group_data_fun(lengths, generator=None): + # counter is decrease order + counter = Counter(lengths) # counter {'1x256x256': 3, ''} lengths ['1x256x256', '1x256x256', '1x256x256', ...] + grouped_indices = defaultdict(list) + for idx, item in enumerate(lengths): # group idx to a list + grouped_indices[item].append(idx) + grouped_indices = dict(grouped_indices) # {'1x256x256': [0, 1, 2], ...} + sorted_indices = [grouped_indices[item] for (item, _) in sorted(counter.items(), key=lambda x: x[1], reverse=True)] -def group_resolution_fun(indices): - raise NotImplementedError - return indices + # shuffle in each group + shuffle_sorted_indices = [] + for indice in sorted_indices: + shuffle_idx = generator.permutation(len(indice)).tolist() + shuffle_sorted_indices.extend([indice[idx] for idx in shuffle_idx]) + return shuffle_sorted_indices -def group_frame_and_resolution_fun(indices): - raise NotImplementedError - return indices - - -def last_group_frame_fun(shuffled_megabatches, lengths): +def last_group_data_fun(shuffled_megabatches, lengths): + # lengths ['1x256x256', '1x256x256', '1x256x256' ...] re_shuffled_megabatches = [] # print('shuffled_megabatches', len(shuffled_megabatches)) for i_megabatch, megabatch in enumerate(shuffled_megabatches): re_megabatch = [] for i_batch, batch in enumerate(megabatch): assert len(batch) != 0 - len_each_batch = [lengths[i] for i in batch] - idx_length_dict = dict([*zip(batch, len_each_batch)]) - count_dict = Counter(len_each_batch) + + len_each_batch = [lengths[i] for i in batch] # ['1x256x256', '1x256x256'] + idx_length_dict = dict([*zip(batch, len_each_batch)]) # {0: '1x256x256', 100: '1x256x256'} + count_dict = Counter(len_each_batch) # {'1x256x256': 2} or {'1x256x256': 1, '1x768x256': 1} if len(count_dict) != 1: - sorted_by_value = sorted(count_dict.items(), key=lambda item: item[1]) + sorted_by_value = sorted( + count_dict.items(), key=lambda item: item[1] + ) # {'1x256x256': 1, '1x768x256': 1} + # import ipdb;ipdb.set_trace() # print(batch, idx_length_dict, count_dict, sorted_by_value) pick_length = sorted_by_value[-1][0] # the highest frequency candidate_batch = [idx for idx, length in idx_length_dict.items() if length == pick_length] @@ -332,6 +339,12 @@ def last_group_frame_fun(shuffled_megabatches, lengths): # print(batch, idx_length_dict, count_dict, sorted_by_value, pick_length, candidate_batch, random_select_batch) batch = candidate_batch + random_select_batch # print(batch) + + for i in range(1, len(batch) - 1): + # if not lengths[batch[0]] == lengths[batch[i]]: + # print(batch, [lengths[i] for i in batch]) + # import ipdb;ipdb.set_trace() + assert lengths[batch[0]] == lengths[batch[i]] re_megabatch.append(batch) re_shuffled_megabatches.append(re_megabatch) @@ -343,48 +356,159 @@ def last_group_frame_fun(shuffled_megabatches, lengths): return re_shuffled_megabatches -def last_group_resolution_fun(indices): - raise NotImplementedError - return indices - +def split_to_even_chunks(megabatch, lengths, world_size, batch_size): + """ + Split a list of indices into `chunks` chunks of roughly equal lengths. + """ + # batch_size=2, world_size=2 + # [1, 2, 3, 4] -> [[1, 2], [3, 4]] + # [1, 2, 3] -> [[1, 2], [3]] + # [1, 2] -> [[1], [2]] + # [1] -> [[1], []] + chunks = [megabatch[i::world_size] for i in range(world_size)] -def last_group_frame_and_resolution_fun(indices): - raise NotImplementedError - return indices + pad_chunks = [] + for idx, chunk in enumerate(chunks): + if batch_size != len(chunk): + assert batch_size > len(chunk) + if len(chunk) != 0: # [[1, 2], [3]] -> [[1, 2], [3, 3]] + chunk = chunk + [random.choice(chunk) for _ in range(batch_size - len(chunk))] + else: + chunk = random.choice(pad_chunks) # [[1], []] -> [[1], [1]] + print(chunks[idx], "->", chunk) + pad_chunks.append(chunk) + return pad_chunks def get_length_grouped_indices( - lengths, batch_size, world_size, generator=None, group_frame=False, group_resolution=False, seed=42 + lengths, + batch_size, + world_size, + gradient_accumulation_size, + initial_global_step, + generator=None, + group_data=False, + seed=42, ): - # We need to use numpy for the random part as a distributed sampler will set the random seed if generator is None: generator = np.random.default_rng(seed) # every rank will generate a fixed order but random index + # print('lengths', lengths) + + if group_data: + indices = group_data_fun(lengths, generator) + else: + indices = generator.permutation(len(lengths)).tolist() + # print('indices', len(indices)) + + # print('sort indices', len(indices)) + # print('sort indices', indices) + # print('sort lengths', [lengths[i] for i in indices]) - indices = generator.permutation(len(lengths)).tolist() - if group_frame and not group_resolution: - indices = group_frame_fun(indices, lengths) - elif not group_frame and group_resolution: - indices = group_resolution_fun(indices) - elif group_frame and group_resolution: - indices = group_frame_and_resolution_fun(indices) megabatch_size = world_size * batch_size megabatches = [indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)] + # import ipdb;ipdb.set_trace() + # print('megabatches', len(megabatches)) + # print('\nmegabatches', megabatches) + # megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] + # import ipdb;ipdb.set_trace() + # print('sort megabatches', len(megabatches)) + # megabatches_len = [[lengths[i] for i in megabatch] for megabatch in megabatches] + # print(f'\nrank {accelerator.process_index} sorted megabatches_len', megabatches_len[0], megabatches_len[1], megabatches_len[-2], megabatches_len[-1]) + # import ipdb;ipdb.set_trace() + megabatches = [split_to_even_chunks(megabatch, lengths, world_size, batch_size) for megabatch in megabatches] + # import ipdb;ipdb.set_trace() + # print('nsplit_to_even_chunks megabatches', len(megabatches)) + # print('\nsplit_to_even_chunks megabatches', megabatches) + # split_to_even_chunks_len = [[[lengths[i] for i in batch] for batch in megabatch] for megabatch in megabatches] + # print(f'\nrank {accelerator.process_index} split_to_even_chunks_len', split_to_even_chunks_len[0], + # split_to_even_chunks_len[1], split_to_even_chunks_len[-2], split_to_even_chunks_len[-1]) + # print('\nsplit_to_even_chunks len', split_to_even_chunks_len) + # return [i for megabatch in megabatches for batch in megabatch for i in batch] + + indices_mega = generator.permutation(len(megabatches)).tolist() + # print(f'rank {accelerator.process_index} seed {seed}, len(megabatches) {len(megabatches)}, indices_mega, {indices_mega[:50]}') + shuffled_megabatches = [megabatches[i] for i in indices_mega] + # shuffled_megabatches_len = [ + # [[lengths[i] for i in batch] for batch in megabatch] for megabatch in shuffled_megabatches + # ] + # print(f'\nrank {accelerator.process_index} sorted shuffled_megabatches_len', shuffled_megabatches_len[0], + # shuffled_megabatches_len[1], shuffled_megabatches_len[-2], shuffled_megabatches_len[-1]) + + # import ipdb;ipdb.set_trace() + # print('shuffled_megabatches', len(shuffled_megabatches)) + if group_data: + shuffled_megabatches = last_group_data_fun(shuffled_megabatches, lengths) + # group_shuffled_megabatches_len = [ + # [[lengths[i] for i in batch] for batch in megabatch] for megabatch in shuffled_megabatches + # ] + # print(f'\nrank {accelerator.process_index} group_shuffled_megabatches_len', group_shuffled_megabatches_len[0], + # group_shuffled_megabatches_len[1], group_shuffled_megabatches_len[-2], group_shuffled_megabatches_len[-1]) + + # import ipdb;ipdb.set_trace() + initial_global_step = initial_global_step * gradient_accumulation_size + # print('shuffled_megabatches', len(shuffled_megabatches)) + # print('have been trained idx:', len(shuffled_megabatches[:initial_global_step])) + # print('shuffled_megabatches[:10]', shuffled_megabatches[:10]) + # print('have been trained idx:', shuffled_megabatches[:initial_global_step]) + shuffled_megabatches = shuffled_megabatches[initial_global_step:] + print(f"Skip the data of {initial_global_step} step!") + # print('after shuffled_megabatches', len(shuffled_megabatches)) + # print('after shuffled_megabatches[:10]', shuffled_megabatches[:10]) + + # print('\nshuffled_megabatches', shuffled_megabatches) + # import ipdb;ipdb.set_trace() + # print('\nshuffled_megabatches len', [[i, lengths[i]] for megabatch in shuffled_megabatches for batch in megabatch for i in batch]) + # return [i for megabatch in shuffled_megabatches for batch in megabatch for i in batch] # return epoch indices in a list + return [batch for megabatch in shuffled_megabatches for batch in megabatch] # return batch indices (list of lists) + + +class LengthGroupedSampler: + r""" + Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while + keeping a bit of randomness. + """ - megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] + def __init__( + self, + batch_size: int, + world_size: int, + gradient_accumulation_size: int, + initial_global_step: int, + lengths: Optional[List[int]] = None, + group_data=False, + generator=None, + ): + if lengths is None: + raise ValueError("Lengths must be provided.") - megabatches = [split_to_even_chunks(megabatch, lengths, world_size, batch_size) for megabatch in megabatches] + self.batch_size = batch_size + self.world_size = world_size + self.initial_global_step = initial_global_step + self.gradient_accumulation_size = gradient_accumulation_size + self.lengths = lengths + self.group_data = group_data + self.generator = generator + # print('self.lengths, self.initial_global_step, self.batch_size, self.world_size, self.gradient_accumulation_size', + # len(self.lengths), self.initial_global_step, self.batch_size, self.world_size, self.gradient_accumulation_size) + + def __len__(self): + return ( + len(self.lengths) + - self.initial_global_step * self.batch_size * self.world_size * self.gradient_accumulation_size + ) - indices = generator.permutation(len(megabatches)).tolist() - shuffled_megabatches = [megabatches[i] for i in indices] - if group_frame and not group_resolution: - shuffled_megabatches = last_group_frame_fun(shuffled_megabatches, lengths) - elif not group_frame and group_resolution: - shuffled_megabatches = last_group_resolution_fun(shuffled_megabatches, indices) - elif group_frame and group_resolution: - shuffled_megabatches = last_group_frame_and_resolution_fun(shuffled_megabatches, indices) + def __iter__(self): + indices = get_length_grouped_indices( + self.lengths, + self.batch_size, + self.world_size, + self.gradient_accumulation_size, + self.initial_global_step, + group_data=self.group_data, + generator=self.generator, + ) - # return [i for megabatch in shuffled_megabatches for batch in megabatch for i in batch] - return [batch for megabatch in shuffled_megabatches for batch in megabatch] # return batch indices + return iter(indices) class LengthGroupedBatchSampler: @@ -401,7 +525,7 @@ def __init__( initial_global_step_for_sampler: int = 0, group_frame=False, group_resolution=False, - group_data = False, + group_data=False, generator=None, ): if lengths is None: diff --git a/examples/opensora_pku/scripts/train_data/video_data_v1_2.txt b/examples/opensora_pku/scripts/train_data/video_data_v1_2.txt index af4c65966e..2b47609591 100644 --- a/examples/opensora_pku/scripts/train_data/video_data_v1_2.txt +++ b/examples/opensora_pku/scripts/train_data/video_data_v1_2.txt @@ -1 +1 @@ -datasets/Open-Sora-Plan-v1.2.0,datasets/Open-Sora-Plan-v1.2.0/mixkit_emb-len=512,datasets/Open-Sora-Plan-v1.2.0/v1.1.0_HQ_part1_Traffic_train.json +datasets/Open-Sora-Plan-v1.3.0/videos_16/,datasets/Open-Sora-Plan-v1.3.0/videos_16_emb-len=512,datasets/Open-Sora-Plan-v1.3.0/opendv_vid16.json diff --git a/examples/opensora_pku/tests/test_data.py b/examples/opensora_pku/tests/test_data.py index 601bbf8de8..b7efd0c9d2 100644 --- a/examples/opensora_pku/tests/test_data.py +++ b/examples/opensora_pku/tests/test_data.py @@ -13,7 +13,7 @@ from opensora.models.causalvideovae import ae_stride_config from opensora.models.diffusion import Diffusion_models from opensora.train.commons import parse_args -from opensora.utils.dataset_utils import Collate, LengthGroupedBatchSampler +from opensora.utils.dataset_utils import Collate, LengthGroupedSampler from opensora.utils.message_utils import print_banner from mindone.utils.config import str2bool @@ -47,31 +47,26 @@ def load_dataset_and_dataloader(args, device_num=1, rank_id=0): assert args.dataset == "t2v", "Support t2v dataset only." print_banner("Dataset Loading") # Setup data: + if args.trained_data_global_step is not None: + initial_global_step_for_sampler = args.trained_data_global_step + else: + initial_global_step_for_sampler = 0 + total_batch_size = args.train_batch_size * device_num * args.gradient_accumulation_steps + total_batch_size = total_batch_size // args.sp_size * args.train_sp_batch_size + args.total_batch_size = total_batch_size + if args.max_hxw is not None and args.min_hxw is None: + args.min_hxw = args.max_hxw // 4 + train_dataset = getdataset(args, dataset_file=args.data) - sampler = ( - LengthGroupedBatchSampler( - args.train_batch_size, - world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), - lengths=train_dataset.lengths, - group_frame=args.group_frame, - group_resolution=args.group_resolution, - ) - if (args.group_frame or args.group_resolution) - else None - ) - collate_fn = Collate( + sampler = LengthGroupedSampler( args.train_batch_size, - args.group_frame, - args.group_resolution, - args.max_height, - args.max_width, - args.ae_stride, - args.ae_stride_t, - args.patch_size, - args.patch_size_t, - args.num_frames, - args.use_image_num, + world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), + gradient_accumulation_size=args.gradient_accumulation_steps, + initial_global_step=initial_global_step_for_sampler, + lengths=train_dataset.lengths, + group_data=args.group_data, ) + collate_fn = Collate(args) dataloader = create_dataloader( train_dataset, batch_size=args.train_batch_size, @@ -101,7 +96,12 @@ def parse_t2v_train_args(parser): parser.add_argument("--hw_stride", type=int, default=32) parser.add_argument("--force_resolution", action="store_true") parser.add_argument("--trained_data_global_step", type=int, default=None) - parser.add_argument("--use_decord", action="store_true") + parser.add_argument( + "--use_decord", + type=str2bool, + default=True, + help="whether to use decord to load videos. If not, use opencv to load videos.", + ) # text encoder & vae & diffusion model parser.add_argument("--vae_fp32", action="store_true") diff --git a/examples/opensora_pku/tests/test_data.sh b/examples/opensora_pku/tests/test_data.sh index 72feb24a23..27e2d6aa8d 100644 --- a/examples/opensora_pku/tests/test_data.sh +++ b/examples/opensora_pku/tests/test_data.sh @@ -3,19 +3,20 @@ python tests/test_data.py \ --text_encoder_name_1 google/mt5-xxl \ --dataset t2v \ --num_frames 93 \ - --data "scripts/train_data/merge_data_mixkit.txt" \ + --data "scripts/train_data/video_data_v1_2.txt" \ --cache_dir "./" \ --ae WFVAEModel_D8_4x8x8 \ --ae_path "LanguageBind/Open-Sora-Plan-v1.3.0/vae" \ --sample_rate 1 \ - --max_height 352 \ + --max_height 640 \ --max_width 640 \ + --max_hxw 409600 \ --train_fps 16 \ - --force_resolution \ --interpolation_scale_t 1.0 \ --interpolation_scale_h 1.0 \ --interpolation_scale_w 1.0 \ - --train_batch_size=8 \ - --dataloader_num_workers 20 \ + --train_batch_size=1 \ + --dataloader_num_workers 8 \ --output_dir="test_data/" \ --model_max_length 512 \ + # --force_resolution \ From 48b9c820e552423f74600e4dd537029ee2e585db Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 23 Dec 2024 10:24:52 +0800 Subject: [PATCH 118/133] set dataset sink mode --- .../scripts/text_condition/multi-devices/train_t2i_stage1.sh | 2 +- .../scripts/text_condition/multi-devices/train_t2v_stage2.sh | 2 +- .../scripts/text_condition/multi-devices/train_t2v_stage3.sh | 2 +- .../scripts/text_condition/single-device/train_t2i_stage1.sh | 2 +- .../scripts/text_condition/single-device/train_t2v_stage2.sh | 2 +- .../scripts/text_condition/single-device/train_t2v_stage3.sh | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh index 630f3c36d6..c678af5e61 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh @@ -49,7 +49,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --parallel_mode "zero" \ --zero_stage 2 \ --max_device_memory "59GB" \ - --dataset_sink_mode False \ + --dataset_sink_mode True\ --prediction_type "v_prediction" \ --hw_stride 32 \ --sparse1d \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh index cc0150734c..3a522f63c4 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh @@ -50,7 +50,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --parallel_mode "zero" \ --zero_stage 2 \ --max_device_memory "59GB" \ - --dataset_sink_mode False \ + --dataset_sink_mode True\ --prediction_type "v_prediction" \ --hw_stride 32 \ --sparse1d \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh index c16fa68c06..f4f455a475 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh @@ -49,7 +49,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --parallel_mode "zero" \ --zero_stage 2 \ --max_device_memory "59GB" \ - --dataset_sink_mode False \ + --dataset_sink_mode True\ --prediction_type "v_prediction" \ --hw_stride 32 \ --sparse1d \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh index 6a9921af1f..bc54b94d57 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh @@ -44,7 +44,7 @@ python opensora/train/train_t2v_diffusers.py \ --speed_factor 1.0 \ --drop_short_ratio 0.0 \ --max_device_memory "59GB" \ - --dataset_sink_mode False \ + --dataset_sink_mode True\ --prediction_type "v_prediction" \ --hw_stride 32 \ --sparse1d \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh index 30c901ca9d..6ca71cf890 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh @@ -48,7 +48,7 @@ python opensora/train/train_t2v_diffusers.py \ --parallel_mode "zero" \ --zero_stage 2 \ --max_device_memory "59GB" \ - --dataset_sink_mode False \ + --dataset_sink_mode True\ --prediction_type "v_prediction" \ --hw_stride 32 \ --sparse1d \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh index fa76ca18c1..45f8f140ae 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh @@ -44,7 +44,7 @@ python opensora/train/train_t2v_diffusers.py \ --speed_factor 1.0 \ --drop_short_ratio 0.0 \ --max_device_memory "59GB" \ - --dataset_sink_mode False \ + --dataset_sink_mode True\ --prediction_type "v_prediction" \ --hw_stride 32 \ --sparse1d \ From 1c602b9c06411adfa1503cd89cc97f6728266ace Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 23 Dec 2024 10:54:36 +0800 Subject: [PATCH 119/133] Revert "set dataset sink mode" This reverts commit 410e689c1c649f0141139f7461483d6983484963. --- .../scripts/text_condition/multi-devices/train_t2i_stage1.sh | 2 +- .../scripts/text_condition/multi-devices/train_t2v_stage2.sh | 2 +- .../scripts/text_condition/multi-devices/train_t2v_stage3.sh | 2 +- .../scripts/text_condition/single-device/train_t2i_stage1.sh | 2 +- .../scripts/text_condition/single-device/train_t2v_stage2.sh | 2 +- .../scripts/text_condition/single-device/train_t2v_stage3.sh | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh index c678af5e61..630f3c36d6 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh @@ -49,7 +49,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --parallel_mode "zero" \ --zero_stage 2 \ --max_device_memory "59GB" \ - --dataset_sink_mode True\ + --dataset_sink_mode False \ --prediction_type "v_prediction" \ --hw_stride 32 \ --sparse1d \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh index 3a522f63c4..cc0150734c 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh @@ -50,7 +50,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --parallel_mode "zero" \ --zero_stage 2 \ --max_device_memory "59GB" \ - --dataset_sink_mode True\ + --dataset_sink_mode False \ --prediction_type "v_prediction" \ --hw_stride 32 \ --sparse1d \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh index f4f455a475..c16fa68c06 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh @@ -49,7 +49,7 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 -- --parallel_mode "zero" \ --zero_stage 2 \ --max_device_memory "59GB" \ - --dataset_sink_mode True\ + --dataset_sink_mode False \ --prediction_type "v_prediction" \ --hw_stride 32 \ --sparse1d \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh index bc54b94d57..6a9921af1f 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2i_stage1.sh @@ -44,7 +44,7 @@ python opensora/train/train_t2v_diffusers.py \ --speed_factor 1.0 \ --drop_short_ratio 0.0 \ --max_device_memory "59GB" \ - --dataset_sink_mode True\ + --dataset_sink_mode False \ --prediction_type "v_prediction" \ --hw_stride 32 \ --sparse1d \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh index 6ca71cf890..30c901ca9d 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage2.sh @@ -48,7 +48,7 @@ python opensora/train/train_t2v_diffusers.py \ --parallel_mode "zero" \ --zero_stage 2 \ --max_device_memory "59GB" \ - --dataset_sink_mode True\ + --dataset_sink_mode False \ --prediction_type "v_prediction" \ --hw_stride 32 \ --sparse1d \ diff --git a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh index 45f8f140ae..fa76ca18c1 100644 --- a/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/single-device/train_t2v_stage3.sh @@ -44,7 +44,7 @@ python opensora/train/train_t2v_diffusers.py \ --speed_factor 1.0 \ --drop_short_ratio 0.0 \ --max_device_memory "59GB" \ - --dataset_sink_mode True\ + --dataset_sink_mode False \ --prediction_type "v_prediction" \ --hw_stride 32 \ --sparse1d \ From c1aefb5555599b8e8c911eadc809bf4576c9426f Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 23 Dec 2024 11:10:07 +0800 Subject: [PATCH 120/133] update Collate usage --- .../opensora/dataset/t2v_datasets.py | 2 ++ .../opensora/train/train_t2v_diffusers.py | 16 ++-------------- .../opensora_pku/opensora/utils/dataset_utils.py | 4 ++-- examples/opensora_pku/tests/test_data.py | 2 +- 4 files changed, 7 insertions(+), 17 deletions(-) diff --git a/examples/opensora_pku/opensora/dataset/t2v_datasets.py b/examples/opensora_pku/opensora/dataset/t2v_datasets.py index 013d4be5b1..5e7f6c81fb 100644 --- a/examples/opensora_pku/opensora/dataset/t2v_datasets.py +++ b/examples/opensora_pku/opensora/dataset/t2v_datasets.py @@ -421,6 +421,8 @@ def define_frame_index(self, data): # shape_idx_dict[pre_define_shape] = [index] # else: # shape_idx_dict[pre_define_shape].append(index) + if len(sample_size) == 0: + raise ValueError("sample_size is empty!") counter = Counter(sample_size) counter_cp = counter if not self.force_resolution and self.max_hxw is not None and self.min_hxw is not None: diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index 5c8e71398b..de19f0b3d6 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -314,7 +314,7 @@ def main(args): lengths=train_dataset.lengths, group_data=args.group_data, ) - collate_fn = Collate(args) + collate_fn = Collate(args.train_batch_size, args) dataloader = create_dataloader( train_dataset, batch_size=args.train_batch_size, @@ -348,19 +348,7 @@ def main(args): group_data=args.group_data, ) - collate_fn = Collate( - args.val_batch_size, - args.group_frame, - args.group_resolution, - args.max_height, - args.max_width, - args.ae_stride, - args.ae_stride_t, - args.patch_size, - args.patch_size_t, - args.num_frames, - args.use_image_num, - ) + collate_fn = Collate(args.val_batch_size, args) val_dataloader = create_dataloader( val_dataset, batch_size=args.val_batch_size, diff --git a/examples/opensora_pku/opensora/utils/dataset_utils.py b/examples/opensora_pku/opensora/utils/dataset_utils.py index 27caeb88eb..59405cbbaa 100644 --- a/examples/opensora_pku/opensora/utils/dataset_utils.py +++ b/examples/opensora_pku/opensora/utils/dataset_utils.py @@ -128,8 +128,8 @@ def pad_to_multiple(number, ds_stride): class Collate: - def __init__(self, args): - self.batch_size = args.train_batch_size + def __init__(self, batch_size, args): + self.batch_size = batch_size self.group_data = args.group_data self.force_resolution = args.force_resolution diff --git a/examples/opensora_pku/tests/test_data.py b/examples/opensora_pku/tests/test_data.py index b7efd0c9d2..b1488c0c59 100644 --- a/examples/opensora_pku/tests/test_data.py +++ b/examples/opensora_pku/tests/test_data.py @@ -66,7 +66,7 @@ def load_dataset_and_dataloader(args, device_num=1, rank_id=0): lengths=train_dataset.lengths, group_data=args.group_data, ) - collate_fn = Collate(args) + collate_fn = Collate(args.train_batch_size, args) dataloader = create_dataloader( train_dataset, batch_size=args.train_batch_size, From 623de9129d3c210c6aa6ca931b76fc5e4a091724 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 23 Dec 2024 15:07:09 +0800 Subject: [PATCH 121/133] allow multiple test data shell --- .../opensora/dataset/t2v_datasets.py | 4 ++- examples/opensora_pku/tests/test_data.py | 16 +++++++++-- ...ata.sh => test_data_dynamic_resolution.sh} | 10 ++++++- .../tests/test_data_fix_resolution.sh | 28 +++++++++++++++++++ 4 files changed, 54 insertions(+), 4 deletions(-) rename examples/opensora_pku/tests/{test_data.sh => test_data_dynamic_resolution.sh} (65%) create mode 100644 examples/opensora_pku/tests/test_data_fix_resolution.sh diff --git a/examples/opensora_pku/opensora/dataset/t2v_datasets.py b/examples/opensora_pku/opensora/dataset/t2v_datasets.py index 5e7f6c81fb..0c9b813ff3 100644 --- a/examples/opensora_pku/opensora/dataset/t2v_datasets.py +++ b/examples/opensora_pku/opensora/dataset/t2v_datasets.py @@ -434,7 +434,9 @@ def define_frame_index(self, data): ) len_before_filter_major = len(sample_size) - filter_major_num = 4 * self.total_batch_size + filter_major_num = ( + self.total_batch_size + ) # allow the sample_size with at least `total_batch_size` samples in the dataset new_cap_list, sample_size = zip( *[[i, j] for i, j in zip(new_cap_list, sample_size) if counter[j] >= filter_major_num] ) diff --git a/examples/opensora_pku/tests/test_data.py b/examples/opensora_pku/tests/test_data.py index b1488c0c59..c83501fd25 100644 --- a/examples/opensora_pku/tests/test_data.py +++ b/examples/opensora_pku/tests/test_data.py @@ -12,6 +12,7 @@ from opensora.dataset.loader import create_dataloader from opensora.models.causalvideovae import ae_stride_config from opensora.models.diffusion import Diffusion_models +from opensora.npu_config import npu_config from opensora.train.commons import parse_args from opensora.utils.dataset_utils import Collate, LengthGroupedSampler from opensora.utils.message_utils import print_banner @@ -102,7 +103,14 @@ def parse_t2v_train_args(parser): default=True, help="whether to use decord to load videos. If not, use opencv to load videos.", ) - + parser.add_argument("--use_parallel", default=False, type=str2bool, help="use parallel") + parser.add_argument( + "--parallel_mode", + default="data", + type=str, + choices=["data", "optim", "semi", "zero"], + help="parallel mode: data, optim, zero", + ) # text encoder & vae & diffusion model parser.add_argument("--vae_fp32", action="store_true") parser.add_argument("--extra_save_mem", action="store_true") @@ -329,7 +337,11 @@ def test_dataloder(dl): args = parse_args(additional_parse_args=parse_t2v_train_args) if args.resume_from_checkpoint == "True": args.resume_from_checkpoint = True - dataset, dataloader = load_dataset_and_dataloader(args) + save_src_strategy = args.use_parallel and args.parallel_mode != "data" + if args.num_frames == 1 or args.use_image_num != 0: + args.sp_size = 1 + rank_id, device_num = npu_config.set_npu_env(args, strategy_ckpt_save_file=save_src_strategy) + dataset, dataloader = load_dataset_and_dataloader(args, device_num=device_num, rank_id=rank_id) test_dataset(dataset) test_dataloder(dataloader) diff --git a/examples/opensora_pku/tests/test_data.sh b/examples/opensora_pku/tests/test_data_dynamic_resolution.sh similarity index 65% rename from examples/opensora_pku/tests/test_data.sh rename to examples/opensora_pku/tests/test_data_dynamic_resolution.sh index 27e2d6aa8d..5597c798d2 100644 --- a/examples/opensora_pku/tests/test_data.sh +++ b/examples/opensora_pku/tests/test_data_dynamic_resolution.sh @@ -1,3 +1,5 @@ +# export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +# msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="./parallel_logs/" \ python tests/test_data.py \ --model OpenSoraT2V_v1_3-2B/122 \ --text_encoder_name_1 google/mt5-xxl \ @@ -19,4 +21,10 @@ python tests/test_data.py \ --dataloader_num_workers 8 \ --output_dir="test_data/" \ --model_max_length 512 \ - # --force_resolution \ + --hw_stride 32 \ + --trained_data_global_step 0 \ + --group_data \ + --sp_size 8 \ + --train_sp_batch_size 1 \ + --dataset_sink_mode False \ + # --use_parallel True \ diff --git a/examples/opensora_pku/tests/test_data_fix_resolution.sh b/examples/opensora_pku/tests/test_data_fix_resolution.sh new file mode 100644 index 0000000000..1b537a73a5 --- /dev/null +++ b/examples/opensora_pku/tests/test_data_fix_resolution.sh @@ -0,0 +1,28 @@ +# export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +# msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="./parallel_logs/" \ +python tests/test_data.py \ + --model OpenSoraT2V_v1_3-2B/122 \ + --text_encoder_name_1 google/mt5-xxl \ + --dataset t2v \ + --num_frames 93 \ + --data "scripts/train_data/video_data_v1_2.txt" \ + --cache_dir "./" \ + --ae WFVAEModel_D8_4x8x8 \ + --ae_path "LanguageBind/Open-Sora-Plan-v1.3.0/vae" \ + --sample_rate 1 \ + --max_height 352 \ + --max_width 640 \ + --force_resolution \ + --train_fps 16 \ + --interpolation_scale_t 1.0 \ + --interpolation_scale_h 1.0 \ + --interpolation_scale_w 1.0 \ + --train_batch_size=1 \ + --dataloader_num_workers 8 \ + --output_dir="test_data/" \ + --model_max_length 512 \ + --hw_stride 32 \ + --trained_data_global_step 0 \ + --group_data \ + --dataset_sink_mode False \ + # --use_parallel True \ From 87679d16596ff45699961f9631425506bd99747a Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 23 Dec 2024 15:33:18 +0800 Subject: [PATCH 122/133] use batch sampler not sampler --- .../opensora/train/train_t2v_diffusers.py | 14 +++++----- .../opensora/utils/dataset_utils.py | 27 ++++++++++++------- examples/opensora_pku/tests/test_data.py | 17 ++++-------- 3 files changed, 30 insertions(+), 28 deletions(-) diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index de19f0b3d6..a0d5ec02ed 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -26,7 +26,7 @@ from opensora.npu_config import npu_config from opensora.train.commons import create_loss_scaler, parse_args from opensora.utils.callbacks import EMAEvalSwapCallback, PerfRecorderCallback -from opensora.utils.dataset_utils import Collate, LengthGroupedSampler +from opensora.utils.dataset_utils import Collate, LengthGroupedBatchSampler from opensora.utils.ema import EMA from opensora.utils.message_utils import print_banner from opensora.utils.utils import get_precision, save_diffusers_json @@ -306,7 +306,7 @@ def main(args): args.min_hxw = args.max_hxw // 4 train_dataset = getdataset(args, dataset_file=args.data) - sampler = LengthGroupedSampler( + batch_sampler = LengthGroupedBatchSampler( args.train_batch_size, world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), gradient_accumulation_size=args.gradient_accumulation_steps, @@ -318,14 +318,14 @@ def main(args): dataloader = create_dataloader( train_dataset, batch_size=args.train_batch_size, - shuffle=sampler is None, + shuffle=batch_sampler is None, device_num=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), rank_id=rank_id if not get_sequence_parallel_state() else hccl_info.group_id, num_parallel_workers=args.dataloader_num_workers, max_rowsize=args.max_rowsize, prefetch_size=args.dataloader_prefetch_size, collate_fn=collate_fn, - sampler=sampler, + batch_sampler=batch_sampler, column_names=["pixel_values", "attention_mask", "text_embed", "encoder_attention_mask"], ) dataloader_size = dataloader.get_dataset_size() @@ -339,7 +339,7 @@ def main(args): assert os.path.exists(args.val_data), f"validation dataset file must exist, but got {args.val_data}" print_banner("Validation dataset Loading...") val_dataset = getdataset(args, dataset_file=args.val_data) - sampler = LengthGroupedSampler( + batch_sampler = LengthGroupedBatchSampler( args.val_batch_size, world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), lengths=val_dataset.lengths, @@ -352,14 +352,14 @@ def main(args): val_dataloader = create_dataloader( val_dataset, batch_size=args.val_batch_size, - shuffle=sampler is None, + shuffle=batch_sampler is None, device_num=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), rank_id=rank_id if not get_sequence_parallel_state() else hccl_info.group_id, num_parallel_workers=args.dataloader_num_workers, max_rowsize=args.max_rowsize, prefetch_size=args.dataloader_prefetch_size, collate_fn=collate_fn, - sampler=sampler, + batch_sampler=batch_sampler, column_names=["pixel_values", "attention_mask", "text_embed", "encoder_attention_mask"], ) val_dataloader_size = val_dataloader.get_dataset_size() diff --git a/examples/opensora_pku/opensora/utils/dataset_utils.py b/examples/opensora_pku/opensora/utils/dataset_utils.py index 59405cbbaa..fc07a0c0d6 100644 --- a/examples/opensora_pku/opensora/utils/dataset_utils.py +++ b/examples/opensora_pku/opensora/utils/dataset_utils.py @@ -389,6 +389,7 @@ def get_length_grouped_indices( generator=None, group_data=False, seed=42, + return_batch_indices=True, ): if generator is None: generator = np.random.default_rng(seed) # every rank will generate a fixed order but random index @@ -458,8 +459,14 @@ def get_length_grouped_indices( # print('\nshuffled_megabatches', shuffled_megabatches) # import ipdb;ipdb.set_trace() # print('\nshuffled_megabatches len', [[i, lengths[i]] for megabatch in shuffled_megabatches for batch in megabatch for i in batch]) - # return [i for megabatch in shuffled_megabatches for batch in megabatch for i in batch] # return epoch indices in a list - return [batch for megabatch in shuffled_megabatches for batch in megabatch] # return batch indices (list of lists) + if not return_batch_indices: + return [ + i for megabatch in shuffled_megabatches for batch in megabatch for i in batch + ] # return epoch indices in a list + else: + return [ + batch for megabatch in shuffled_megabatches for batch in megabatch + ] # return batch indices (list of lists) class LengthGroupedSampler: @@ -506,6 +513,7 @@ def __iter__(self): self.initial_global_step, group_data=self.group_data, generator=self.generator, + return_batch_indices=False, ) return iter(indices) @@ -521,10 +529,9 @@ def __init__( self, batch_size: int, world_size: int, + gradient_accumulation_size: int, lengths: Optional[List[int]] = None, - initial_global_step_for_sampler: int = 0, - group_frame=False, - group_resolution=False, + initial_global_step: int = 0, group_data=False, generator=None, ): @@ -534,9 +541,9 @@ def __init__( self.batch_size = batch_size self.world_size = world_size self.megabatch_size = self.world_size * self.batch_size + self.initial_global_step = initial_global_step + self.gradient_accumulation_size = gradient_accumulation_size self.lengths = lengths - self.group_frame = group_frame - self.group_resolution = group_resolution self.group_data = group_data self.generator = generator self.remainder = len(self) * self.megabatch_size != len(self.lengths) @@ -549,8 +556,10 @@ def __iter__(self): self.lengths, self.batch_size, self.world_size, - group_frame=self.group_frame, - group_resolution=self.group_resolution, + self.gradient_accumulation_size, + self.initial_global_step, + group_data=self.group_data, generator=self.generator, + return_batch_indices=True, ) return iter(indices) diff --git a/examples/opensora_pku/tests/test_data.py b/examples/opensora_pku/tests/test_data.py index c83501fd25..de6c6cd01f 100644 --- a/examples/opensora_pku/tests/test_data.py +++ b/examples/opensora_pku/tests/test_data.py @@ -14,7 +14,7 @@ from opensora.models.diffusion import Diffusion_models from opensora.npu_config import npu_config from opensora.train.commons import parse_args -from opensora.utils.dataset_utils import Collate, LengthGroupedSampler +from opensora.utils.dataset_utils import Collate, LengthGroupedBatchSampler from opensora.utils.message_utils import print_banner from mindone.utils.config import str2bool @@ -59,7 +59,7 @@ def load_dataset_and_dataloader(args, device_num=1, rank_id=0): args.min_hxw = args.max_hxw // 4 train_dataset = getdataset(args, dataset_file=args.data) - sampler = LengthGroupedSampler( + batch_sampler = LengthGroupedBatchSampler( args.train_batch_size, world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), gradient_accumulation_size=args.gradient_accumulation_steps, @@ -71,14 +71,14 @@ def load_dataset_and_dataloader(args, device_num=1, rank_id=0): dataloader = create_dataloader( train_dataset, batch_size=args.train_batch_size, - shuffle=sampler is None, + shuffle=batch_sampler is None, device_num=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), rank_id=rank_id if not get_sequence_parallel_state() else hccl_info.group_id, num_parallel_workers=args.dataloader_num_workers, max_rowsize=args.max_rowsize, prefetch_size=args.dataloader_prefetch_size, collate_fn=collate_fn, - sampler=sampler, + batch_sampler=batch_sampler, column_names=["pixel_values", "attention_mask", "text_embed", "encoder_attention_mask"], drop_last=True, ) @@ -103,14 +103,7 @@ def parse_t2v_train_args(parser): default=True, help="whether to use decord to load videos. If not, use opencv to load videos.", ) - parser.add_argument("--use_parallel", default=False, type=str2bool, help="use parallel") - parser.add_argument( - "--parallel_mode", - default="data", - type=str, - choices=["data", "optim", "semi", "zero"], - help="parallel mode: data, optim, zero", - ) + # text encoder & vae & diffusion model parser.add_argument("--vae_fp32", action="store_true") parser.add_argument("--extra_save_mem", action="store_true") From 31bdcfd92578a15c8ecd8459e46a20bfe73d95a6 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 23 Dec 2024 16:36:41 +0800 Subject: [PATCH 123/133] use sampler instead of batch sampler --- .../opensora_pku/opensora/dataset/loader.py | 55 +++++++++++------ .../opensora/train/train_t2v_diffusers.py | 14 ++--- .../opensora/utils/dataset_utils.py | 59 +------------------ examples/opensora_pku/tests/test_data.py | 8 +-- 4 files changed, 51 insertions(+), 85 deletions(-) diff --git a/examples/opensora_pku/opensora/dataset/loader.py b/examples/opensora_pku/opensora/dataset/loader.py index d53372a0dc..469ef445e3 100644 --- a/examples/opensora_pku/opensora/dataset/loader.py +++ b/examples/opensora_pku/opensora/dataset/loader.py @@ -75,7 +75,8 @@ def build_dataloader( shuffle=True, drop_last=True, ): - if batch_sampler is None: + if batch_sampler is None and sampler is None: + # use batch sampler if not specified batch_sampler = BatchSampler(datalens, batch_size=batch_size, device_num=device_num, shuffle=shuffle) loader = DataLoader( dataset, @@ -92,7 +93,7 @@ def build_dataloader( class BatchSampler: """ - Batch Sampler + Batch Sampler that return batches of indices instead of single indices """ def __init__(self, lens, batch_size, device_num, shuffle): @@ -134,39 +135,57 @@ def __init__( self.dataset = dataset self.sampler = sampler self.batch_sampler = batch_sampler - self.collat_fn = collate_fn + if self.sampler is not None and self.batch_sampler is not None: + raise ValueError("Cannot specify both a sampler and a batch sampler simultaneously!") + self.collate_fn = collate_fn self.device_num = device_num self.rank_id = rank_id self.drop_last = drop_last self.batch_size = batch_size + self._batch_size = batch_size * device_num def __iter__(self): - self.step_index = 0 - self.batch_indices = iter(self.batch_sampler) - - return self + if self.batch_sampler is not None: + # Use batch_sampler to get batches directly + return iter(self.batch_sampler) + else: + # Use sampler to get indices and create batches + self.sampler_iter = iter(self.sampler) + return self def __next__(self): - indices = next(self.batch_indices) - if len(indices) != self.batch_size and self.drop_last: + if self.batch_sampler is not None: + # Get the next batch directly from the batch sampler + batch_indices = next(self.batch_sampler) + else: + # Get the next indices from the sampler + batch_indices = [next(self.sampler_iter) for _ in range(self._batch_size)] + if len(batch_indices) != self._batch_size and self.drop_last: raise StopIteration() data = [] - per_batch = len(indices) // self.device_num - index = indices[self.rank_id * per_batch : (self.rank_id + 1) * per_batch] + per_batch = len(batch_indices) // self.device_num + index = batch_indices[self.rank_id * per_batch : (self.rank_id + 1) * per_batch] for idx in index: data.append(self.dataset[idx]) - if self.collat_fn is not None: - data = self.collat_fn(data) + if self.collate_fn is not None: + data = self.collate_fn(data) return data def __len__(self): - batch_sampler_len = len(self.batch_sampler) - remainder = self.batch_sampler.remainder - if remainder and self.drop_last: - return batch_sampler_len - 1 + if self.batch_sampler is not None: + batch_sampler_len = len(self.batch_sampler) + remainder = self.batch_sampler.remainder + if remainder and self.drop_last: + return batch_sampler_len - 1 + else: + return batch_sampler_len else: - return batch_sampler_len + remainder = len(self.sampler) % self._batch_size != 0 + if remainder and not self.drop_last: + return len(self.sampler) // self._batch_size + 1 + else: + return len(self.sampler) // self._batch_size class MetaLoader: diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index a0d5ec02ed..de19f0b3d6 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -26,7 +26,7 @@ from opensora.npu_config import npu_config from opensora.train.commons import create_loss_scaler, parse_args from opensora.utils.callbacks import EMAEvalSwapCallback, PerfRecorderCallback -from opensora.utils.dataset_utils import Collate, LengthGroupedBatchSampler +from opensora.utils.dataset_utils import Collate, LengthGroupedSampler from opensora.utils.ema import EMA from opensora.utils.message_utils import print_banner from opensora.utils.utils import get_precision, save_diffusers_json @@ -306,7 +306,7 @@ def main(args): args.min_hxw = args.max_hxw // 4 train_dataset = getdataset(args, dataset_file=args.data) - batch_sampler = LengthGroupedBatchSampler( + sampler = LengthGroupedSampler( args.train_batch_size, world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), gradient_accumulation_size=args.gradient_accumulation_steps, @@ -318,14 +318,14 @@ def main(args): dataloader = create_dataloader( train_dataset, batch_size=args.train_batch_size, - shuffle=batch_sampler is None, + shuffle=sampler is None, device_num=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), rank_id=rank_id if not get_sequence_parallel_state() else hccl_info.group_id, num_parallel_workers=args.dataloader_num_workers, max_rowsize=args.max_rowsize, prefetch_size=args.dataloader_prefetch_size, collate_fn=collate_fn, - batch_sampler=batch_sampler, + sampler=sampler, column_names=["pixel_values", "attention_mask", "text_embed", "encoder_attention_mask"], ) dataloader_size = dataloader.get_dataset_size() @@ -339,7 +339,7 @@ def main(args): assert os.path.exists(args.val_data), f"validation dataset file must exist, but got {args.val_data}" print_banner("Validation dataset Loading...") val_dataset = getdataset(args, dataset_file=args.val_data) - batch_sampler = LengthGroupedBatchSampler( + sampler = LengthGroupedSampler( args.val_batch_size, world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), lengths=val_dataset.lengths, @@ -352,14 +352,14 @@ def main(args): val_dataloader = create_dataloader( val_dataset, batch_size=args.val_batch_size, - shuffle=batch_sampler is None, + shuffle=sampler is None, device_num=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), rank_id=rank_id if not get_sequence_parallel_state() else hccl_info.group_id, num_parallel_workers=args.dataloader_num_workers, max_rowsize=args.max_rowsize, prefetch_size=args.dataloader_prefetch_size, collate_fn=collate_fn, - batch_sampler=batch_sampler, + sampler=sampler, column_names=["pixel_values", "attention_mask", "text_embed", "encoder_attention_mask"], ) val_dataloader_size = val_dataloader.get_dataset_size() diff --git a/examples/opensora_pku/opensora/utils/dataset_utils.py b/examples/opensora_pku/opensora/utils/dataset_utils.py index fc07a0c0d6..450e925618 100644 --- a/examples/opensora_pku/opensora/utils/dataset_utils.py +++ b/examples/opensora_pku/opensora/utils/dataset_utils.py @@ -389,7 +389,6 @@ def get_length_grouped_indices( generator=None, group_data=False, seed=42, - return_batch_indices=True, ): if generator is None: generator = np.random.default_rng(seed) # every rank will generate a fixed order but random index @@ -459,14 +458,9 @@ def get_length_grouped_indices( # print('\nshuffled_megabatches', shuffled_megabatches) # import ipdb;ipdb.set_trace() # print('\nshuffled_megabatches len', [[i, lengths[i]] for megabatch in shuffled_megabatches for batch in megabatch for i in batch]) - if not return_batch_indices: - return [ - i for megabatch in shuffled_megabatches for batch in megabatch for i in batch - ] # return epoch indices in a list - else: - return [ - batch for megabatch in shuffled_megabatches for batch in megabatch - ] # return batch indices (list of lists) + return [ + i for megabatch in shuffled_megabatches for batch in megabatch for i in batch + ] # return epoch indices in a single list class LengthGroupedSampler: @@ -513,53 +507,6 @@ def __iter__(self): self.initial_global_step, group_data=self.group_data, generator=self.generator, - return_batch_indices=False, ) return iter(indices) - - -class LengthGroupedBatchSampler: - r""" - Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while - keeping a bit of randomness. - """ - - def __init__( - self, - batch_size: int, - world_size: int, - gradient_accumulation_size: int, - lengths: Optional[List[int]] = None, - initial_global_step: int = 0, - group_data=False, - generator=None, - ): - if lengths is None: - raise ValueError("Lengths must be provided.") - - self.batch_size = batch_size - self.world_size = world_size - self.megabatch_size = self.world_size * self.batch_size - self.initial_global_step = initial_global_step - self.gradient_accumulation_size = gradient_accumulation_size - self.lengths = lengths - self.group_data = group_data - self.generator = generator - self.remainder = len(self) * self.megabatch_size != len(self.lengths) - - def __len__(self): - return len(list(range(0, len(self.lengths), self.megabatch_size))) - - def __iter__(self): - indices = get_length_grouped_indices( - self.lengths, - self.batch_size, - self.world_size, - self.gradient_accumulation_size, - self.initial_global_step, - group_data=self.group_data, - generator=self.generator, - return_batch_indices=True, - ) - return iter(indices) diff --git a/examples/opensora_pku/tests/test_data.py b/examples/opensora_pku/tests/test_data.py index de6c6cd01f..6fffbe34f7 100644 --- a/examples/opensora_pku/tests/test_data.py +++ b/examples/opensora_pku/tests/test_data.py @@ -14,7 +14,7 @@ from opensora.models.diffusion import Diffusion_models from opensora.npu_config import npu_config from opensora.train.commons import parse_args -from opensora.utils.dataset_utils import Collate, LengthGroupedBatchSampler +from opensora.utils.dataset_utils import Collate, LengthGroupedSampler from opensora.utils.message_utils import print_banner from mindone.utils.config import str2bool @@ -59,7 +59,7 @@ def load_dataset_and_dataloader(args, device_num=1, rank_id=0): args.min_hxw = args.max_hxw // 4 train_dataset = getdataset(args, dataset_file=args.data) - batch_sampler = LengthGroupedBatchSampler( + sampler = LengthGroupedSampler( args.train_batch_size, world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), gradient_accumulation_size=args.gradient_accumulation_steps, @@ -71,14 +71,14 @@ def load_dataset_and_dataloader(args, device_num=1, rank_id=0): dataloader = create_dataloader( train_dataset, batch_size=args.train_batch_size, - shuffle=batch_sampler is None, + shuffle=sampler is None, device_num=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), rank_id=rank_id if not get_sequence_parallel_state() else hccl_info.group_id, num_parallel_workers=args.dataloader_num_workers, max_rowsize=args.max_rowsize, prefetch_size=args.dataloader_prefetch_size, collate_fn=collate_fn, - batch_sampler=batch_sampler, + sampler=sampler, column_names=["pixel_values", "attention_mask", "text_embed", "encoder_attention_mask"], drop_last=True, ) From 13407fff5c62b4d4906b4ff19ed81a948093fb74 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Mon, 23 Dec 2024 17:38:38 +0800 Subject: [PATCH 124/133] caption refiner --- examples/opensora_pku/README.md | 13 ++++++++++++- .../opensora_pku/opensora/sample/caption_refiner.py | 2 +- examples/opensora_pku/opensora/sample/sample.py | 8 ++++---- .../opensora_pku/opensora/utils/sample_utils.py | 2 -- 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/examples/opensora_pku/README.md b/examples/opensora_pku/README.md index d0bf16801b..ff4e06104e 100644 --- a/examples/opensora_pku/README.md +++ b/examples/opensora_pku/README.md @@ -56,7 +56,7 @@ Videos are saved to `.gif` for display. - 📍 **Open-Sora-Plan v1.3.0** with the following features - ✅ WFVAE inference & multi-stage training. - ✅ mT5-xxl TextEncoder model inference. - - ✅ Prompt Refiner. + - ✅ Prompt Refiner Inference. - ✅ Text-to-video generation up to 93 frames and 640x640 resolution. - ✅ Multi-stage training using Zero2 and sequence parallelism. - ✅ Acceleration methods: flash attention, recompute (graident checkpointing), mixed precision, data parallelism, etc.. @@ -220,6 +220,17 @@ Please edit the `master_port` to a different port number in the range 1024 to 65 See more examples of multi-device inference scripts under `scripts/text_condifion/multi-devices`. +### Prompt Refiner Inference + +If you want to run T2V inference with caption refiner, you should attach to following argument to the T2V inference command above: +``` + --caption_refiner "LanguageBind/Open-Sora-Plan-v1.3.0/prompt_refiner/" +``` + +If you just want to run prompt refinement, please run: +```bash +python opensora/sample/caption_refiner.py +``` ### Sequence Parallelism We support running inference with sequence parallelism. Please see the `sample_t2v_93x640_sp.sh` under `scripts/text_condition/multi-devices/`. The script will run a 8-card inference with `sp_size=8`, which means each video tensor is sliced into 8 parts along the sequence dimension. If you want to try `sp_size=4`, you can revise it as below: diff --git a/examples/opensora_pku/opensora/sample/caption_refiner.py b/examples/opensora_pku/opensora/sample/caption_refiner.py index a889a90169..649e1602f2 100644 --- a/examples/opensora_pku/opensora/sample/caption_refiner.py +++ b/examples/opensora_pku/opensora/sample/caption_refiner.py @@ -29,7 +29,7 @@ def get_refiner_output(self, prompt): if __name__ == "__main__": - pretrained_model_name_or_path = "" + pretrained_model_name_or_path = "LanguageBind/Open-Sora-Plan-v1.3.0/prompt_refiner/" caption_refiner = OpenSoraCaptionRefiner(pretrained_model_name_or_path, dtype=ms.float16) prompt = "a video of a girl playing in the park" response = caption_refiner.get_refiner_output(prompt) diff --git a/examples/opensora_pku/opensora/sample/sample.py b/examples/opensora_pku/opensora/sample/sample.py index 8e0fd6621a..9bf3121415 100644 --- a/examples/opensora_pku/opensora/sample/sample.py +++ b/examples/opensora_pku/opensora/sample/sample.py @@ -1,6 +1,8 @@ import os import sys +import mindspore as ms + # TODO: remove in future when mindone is ready for install mindone_lib_path = os.path.abspath("../../") sys.path.insert(0, mindone_lib_path) @@ -10,13 +12,11 @@ import time from opensora.npu_config import npu_config +from opensora.sample.caption_refiner import OpenSoraCaptionRefiner from opensora.utils.sample_utils import get_args, prepare_pipeline, run_model_and_save_samples from mindone.utils.logger import set_logger -# from opensora.sample.caption_refiner import OpenSoraCaptionRefiner - - logger = logging.getLogger(__name__) if __name__ == "__main__": @@ -39,7 +39,7 @@ pipeline = prepare_pipeline(args) # build I2V/T2V pipeline - if args.caption_refiner is not None: # TODO: TO TEST + if args.caption_refiner is not None: caption_refiner_model = OpenSoraCaptionRefiner(args.caption_refiner, dtype=ms.float16) else: caption_refiner_model = None diff --git a/examples/opensora_pku/opensora/utils/sample_utils.py b/examples/opensora_pku/opensora/utils/sample_utils.py index ea1cce0826..97dee27041 100644 --- a/examples/opensora_pku/opensora/utils/sample_utils.py +++ b/examples/opensora_pku/opensora/utils/sample_utils.py @@ -10,8 +10,6 @@ from opensora.acceleration.parallel_states import get_sequence_parallel_state, hccl_info from opensora.dataset.text_dataset import create_dataloader from opensora.models.causalvideovae import ae_stride_config, ae_wrapper - -# from opensora.sample.caption_refiner import OpenSoraCaptionRefiner from opensora.models.diffusion.common import PatchEmbed2D from opensora.models.diffusion.opensora.modeling_opensora import OpenSoraT2V_v1_3 from opensora.models.diffusion.opensora.modules import Attention, LayerNorm From 733a44621499a19646c94e6588cce4cf5a98ae93 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 24 Dec 2024 09:35:30 +0800 Subject: [PATCH 125/133] fix PR run error --- .../model/losses/discriminator.py | 4 +- .../models/text_encoder/t5_encoder.py | 1 + examples/opensora_pku/tests/test_wavelet.py | 1 + .../readme_load_states.md | 210 ------------------ 4 files changed, 4 insertions(+), 212 deletions(-) delete mode 100644 examples/opensora_pku/torch_intermediate_states/readme_load_states.md diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py index 8c83c9d4af..252a2b8a3d 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/losses/discriminator.py @@ -65,7 +65,7 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.f nf_mult_prev = nf_mult nf_mult = min(2**n, 8) sequence += [ - Conv3d( + nn.Conv3d( ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), @@ -84,7 +84,7 @@ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, dtype=ms.f nf_mult_prev = nf_mult nf_mult = min(2**n_layers, 8) sequence += [ - Conv3d( + nn.Conv3d( ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), diff --git a/examples/opensora_pku/opensora/models/text_encoder/t5_encoder.py b/examples/opensora_pku/opensora/models/text_encoder/t5_encoder.py index 33fbb2126f..2dff783efd 100644 --- a/examples/opensora_pku/opensora/models/text_encoder/t5_encoder.py +++ b/examples/opensora_pku/opensora/models/text_encoder/t5_encoder.py @@ -7,6 +7,7 @@ import mindspore as ms import mindspore.nn as nn +import mindspore.ops as ops from mindone.transformers.activations import ACT2FN from mindone.transformers.modeling_utils import MSPreTrainedModel as PreTrainedModel diff --git a/examples/opensora_pku/tests/test_wavelet.py b/examples/opensora_pku/tests/test_wavelet.py index 2afa5fe75c..438ba4f636 100644 --- a/examples/opensora_pku/tests/test_wavelet.py +++ b/examples/opensora_pku/tests/test_wavelet.py @@ -15,6 +15,7 @@ InverseHaarWaveletTransform2D, InverseHaarWaveletTransform3D, ) + from tests.torch_wavelet import HaarWaveletTransform2D as HaarWaveletTransform2D_torch from tests.torch_wavelet import HaarWaveletTransform3D as HaarWaveletTransform3D_torch from tests.torch_wavelet import InverseHaarWaveletTransform2D as InverseHaarWaveletTransform2D_torch diff --git a/examples/opensora_pku/torch_intermediate_states/readme_load_states.md b/examples/opensora_pku/torch_intermediate_states/readme_load_states.md deleted file mode 100644 index 89d923b752..0000000000 --- a/examples/opensora_pku/torch_intermediate_states/readme_load_states.md +++ /dev/null @@ -1,210 +0,0 @@ -### Updated files -- opensora/utils/sample.utils.py (NEW) -- opensora/sample/sample.py <- opensora/sample/sample_t2v.py -- opensora/sample/pipeline_opensora.py - -- opensora/models/diffusion/common.py <- opensora/models/diffusion/opensora/rope.py -- opensora/models/diffusion/opensora/modeling_opensora.py -- opensora/models/diffusion/opensora/modules.py - -### Debugging script -scripts/text_condition/single-device/sample_debug.sh - -### Intermediate dicts to load -Details of saving intermediate states in Pytorch Version opensora/models/diffusion/modeling_opensora.py forward(): - -Note: I only save them in first step of denoising. - -```python -def forward( - self, - hidden_states: torch.Tensor, - timestep: Optional[torch.LongTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - return_dict: bool = True, - **kwargs, - ): - ##################################### - ## !!!SAVE `input parameters - np.save("./hidden_states_input.npy", hidden_states.float().cpu().numpy()) - np.save("./timestep_input.npy", timestep.float().cpu().numpy()) - np.save("./encoder_hidden_states_input.npy", encoder_hidden_states.float().cpu().numpy()) - np.save("./attention_mask_input.npy", attention_mask.float().cpu().numpy()) - np.save("./encoder_attention_mask_input.npy", encoder_attention_mask.float().cpu().numpy()) - ##################################### - - batch_size, c, frame, h, w = hidden_states.shape - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. - # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. - # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. - # expects mask of shape: - # [batch, key_tokens] - # adds singleton query_tokens dimension: - # [batch, 1, key_tokens] - # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: - # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) - # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) - if attention_mask is not None and attention_mask.ndim == 4: - # assume that mask is expressed as: - # (1 = keep, 0 = discard) - # convert mask into a bias that can be added to attention scores: - # (keep = +0, discard = -10000.0) - # b, frame, h, w -> a video - # b, 1, h, w -> only images - attention_mask = attention_mask.to(self.dtype) - - attention_mask = attention_mask.unsqueeze(1) # b 1 t h w - attention_mask = F.max_pool3d( - attention_mask, - kernel_size=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size), - stride=(self.config.patch_size_t, self.config.patch_size, self.config.patch_size) - ) - attention_mask = rearrange(attention_mask, 'b 1 t h w -> (b 1) 1 (t h w)') - attention_mask = (1 - attention_mask.bool().to(self.dtype)) * -10000.0 - - - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: - # b, 1, l - encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 - - ##################################### - ## !!! SAVE `masks` after conversion - # Note: they used "0" as True, "-10000" as False, and masks have transposed dimension - # I do not suggest to load these masks in Mindspore version - np.save("./attention_mask_converted.npy", attention_mask.float().cpu().numpy()) - np.save("./encoder_attention_mask_converted.npy", encoder_attention_mask.float().cpu().numpy()) - ##################################### - - # 1. Input - frame = ((frame - 1) // self.config.patch_size_t + 1) if frame % 2 == 1 else frame // self.config.patch_size_t # patchfy - height, width = hidden_states.shape[-2] // self.config.patch_size, hidden_states.shape[-1] // self.config.patch_size - - hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( - hidden_states, encoder_hidden_states, timestep, batch_size, frame - ) - - ##################################### - ## !!! SAVE states after `_operate_on_patched_inputs` - np.save("./hidden_states_operate_on_patched_inputs.npy", hidden_states.float().cpu().numpy()) - np.save("./encoder_hidden_states_operate_on_patched_inputs.npy", encoder_hidden_states.float().cpu().numpy()) - np.save("./timestep_operate_on_patched_inputs.npy", timestep.float().cpu().numpy()) - np.save("./embedded_timestep_operate_on_patched_inputs.npy", embedded_timestep.float().cpu().numpy()) - ##################################### - - # To - # x (t*h*w b d) or (t//sp*h*w b d) - # cond_1 (l b d) or (l//sp b d) - hidden_states = rearrange(hidden_states, 'b s h -> s b h', b=batch_size).contiguous() - encoder_hidden_states = rearrange(encoder_hidden_states, 'b s h -> s b h', b=batch_size).contiguous() - timestep = timestep.view(batch_size, 6, -1).transpose(0, 1).contiguous() - - sparse_mask = {} - if npu_config is None: - if get_sequence_parallel_state(): - head_num = self.config.num_attention_heads // nccl_info.world_size - else: - head_num = self.config.num_attention_heads - else: - head_num = None - for sparse_n in [1, 4]: - sparse_mask[sparse_n] = Attention.prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n, head_num) - ##################################### - ## !!! SAVE sparse masks - # Note: they used "0" as True, "-10000" as False, and masks have transposed dimension - # I do not suggest to load these masks in Mindspore version - attention_mask_sparse_1_False, encoder_attention_mask_sparse_1_False = sparse_mask[1][False] # mask_sparse_1d - attention_mask_sparse_1_True, encoder_attention_mask_sparse_1_True = sparse_mask[1][True] # mask_sparse_1d_group - attention_mask_sparse_4_False, encoder_attention_mask_sparse_4_False = sparse_mask[4][False] # mask_sparse_1d - attention_mask_sparse_4_True, encoder_attention_mask_sparse_4_True = sparse_mask[4][True] # sparse_1d_group - np.save("./attention_mask_sparse_1_False.npy", attention_mask_sparse_1_False.float().cpu().numpy()) - np.save("./encoder_attention_mask_sparse_1_False.npy", encoder_attention_mask_sparse_1_False.float().cpu().numpy()) - np.save("./attention_mask_sparse_1_True.npy", attention_mask_sparse_1_True.float().cpu().numpy()) - np.save("./encoder_attention_mask_sparse_1_True.npy", encoder_attention_mask_sparse_1_True.float().cpu().numpy()) - np.save("./attention_mask_sparse_4_False.npy", attention_mask_sparse_4_False.float().cpu().numpy()) - np.save("./encoder_attention_mask_sparse_4_False.npy", encoder_attention_mask_sparse_4_False.float().cpu().numpy()) - np.save("./attention_mask_sparse_4_True.npy", attention_mask_sparse_4_True.float().cpu().numpy()) - np.save("./encoder_attention_mask_sparse_4_True.npy", encoder_attention_mask_sparse_4_True.float().cpu().numpy()) - ##################################### - - - # 2. Blocks - ##################################### - # !!! SAVE initial input states - np.save(f"./hidden_states_before_block.npy", hidden_states.float().cpu().numpy()) - ##################################### - for i, block in enumerate(self.transformer_blocks): - if i > 1 and i < 30: - attention_mask, encoder_attention_mask = sparse_mask[block.attn1.processor.sparse_n][block.attn1.processor.sparse_group] - else: - attention_mask, encoder_attention_mask = sparse_mask[1][block.attn1.processor.sparse_group] - - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - timestep, - frame, - height, - width, - **ckpt_kwargs, - ) - else: - hidden_states = block( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - timestep=timestep, - frame=frame, - height=height, - width=width, - ) - ##################################### - # !!! SAVE updated states - np.save(f"./hidden_states_{i}_block.npy", hidden_states.float().cpu().numpy()) - ##################################### - - - # To (b, t*h*w, h) or (b, t//sp*h*w, h) - hidden_states = rearrange(hidden_states, 's b h -> b s h', b=batch_size).contiguous() - - # 3. Output - output = self._get_output_for_patched_inputs( - hidden_states=hidden_states, - timestep=timestep, - embedded_timestep=embedded_timestep, - num_frames=frame, - height=height, - width=width, - ) # b c t h w - - ##################################### - #!!! SAVE output hidden states - np.save("./hidden_states_output.npy", output.float().cpu().numpy()) - ##################################### - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) - -``` \ No newline at end of file From 626e84251ba456989531c88fd07a963f59801163 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 24 Dec 2024 09:45:40 +0800 Subject: [PATCH 126/133] fix PR run error --- .../opensora/models/diffusion/opensora/modules.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py index 422fa11d69..0a34b11238 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modules.py @@ -59,11 +59,11 @@ def __init__(self, interpolation_scale_thw, sparse1d, sparse_n, sparse_group, is def prepare_sparse_mask(attention_mask, encoder_attention_mask, sparse_n, head_num): attention_mask = attention_mask.unsqueeze(1) encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - l = attention_mask.shape[-1] - if l % (sparse_n * sparse_n) == 0: + length = attention_mask.shape[-1] + if length % (sparse_n * sparse_n) == 0: pad_len = 0 else: - pad_len = sparse_n * sparse_n - l % (sparse_n * sparse_n) + pad_len = sparse_n * sparse_n - length % (sparse_n * sparse_n) attention_mask_sparse = mint.nn.functional.pad( attention_mask, (0, pad_len, 0, 0), mode="constant", value=0 @@ -199,11 +199,11 @@ def _sparse_1d(self, x, frame, height, width): x: shape if sparse_group: (S//sparse_n, sparse_n*B, D), else: (S//sparse_n, sparse_n*B, D) pad_len: 0 or padding """ - l = x.shape[0] - assert l == frame * height * width + length = x.shape[0] + assert length == frame * height * width pad_len = 0 - if l % (self.sparse_n * self.sparse_n) != 0: - pad_len = self.sparse_n * self.sparse_n - l % (self.sparse_n * self.sparse_n) + if length % (self.sparse_n * self.sparse_n) != 0: + pad_len = self.sparse_n * self.sparse_n - length % (self.sparse_n * self.sparse_n) if pad_len != 0: x = mint.nn.functional.pad(x, (0, 0, 0, 0, 0, pad_len), mode="constant", value=0.0) From 1cdeb9fccf2a80c11ff0d1d26084efb527a0621e Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 24 Dec 2024 10:19:08 +0800 Subject: [PATCH 127/133] mindnlp commit --- examples/opensora_pku/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/opensora_pku/requirements.txt b/examples/opensora_pku/requirements.txt index 987beb7b4f..367ba4c852 100644 --- a/examples/opensora_pku/requirements.txt +++ b/examples/opensora_pku/requirements.txt @@ -9,6 +9,7 @@ imagesize toolz tqdm mindcv +mindnlp safetensors omegaconf pyyaml From 54254703a25b6b713e98f0ac4627dad0b8638b96 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 24 Dec 2024 11:07:34 +0800 Subject: [PATCH 128/133] plot loss curve --- examples/opensora_pku/tools/plot.py | 63 +++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 examples/opensora_pku/tools/plot.py diff --git a/examples/opensora_pku/tools/plot.py b/examples/opensora_pku/tools/plot.py new file mode 100644 index 0000000000..9f924c1abb --- /dev/null +++ b/examples/opensora_pku/tools/plot.py @@ -0,0 +1,63 @@ +""" +Usage: +python tools/plot.py --input path/to/exp1/result.log path/to/exp2/result.log \ + --smooth --alpha 0.001 --y_max 0.6 --output loss_cmp.png +""" + + +import argparse + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import pandas as pd + + +def plot(inp, output, smooth=False, alpha=0.01, interval=1, duration=-1, linewidth=1, y_max=None): + num_curve = len(inp) + plt.figure() + plt.title("loss") + f, ax = plt.subplots(figsize=(14, 6)) + if y_max is not None: + ax.set_ylim([0, y_max]) + + for i in range(num_curve): + log_path = inp[i] + + df = pd.read_csv(log_path, sep="\t") + if smooth: + print("curve soomthing enabled with alpha=", alpha) + loss = df["loss"].ewm(alpha=alpha).mean().values + else: + loss = df["loss"].values + step = df["step"].values + + ax.plot(step[:duration:interval], loss[:duration:interval], label=f"loss_{log_path}", linewidth=linewidth) + + ax.set_xlabel("steps") + ax.set_ylabel("loss") + ax.legend() + ax.grid() + + plt.savefig(output) + print("Figure saved in ", output) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input", "-i", type=str, nargs="+", default=None, help="list of path to result log") + parser.add_argument( + "--output", "-o", type=str, default="loss_curve.png", help="target file path to save the output loss curve" + ) + parser.add_argument( + "--smooth", action="store_true", help="smooth curve by exponential weighted (ema). default: False" + ) + parser.add_argument( + "--alpha", default=0.01, type=float, help="smooth factor alpha, the smaller this value, the smoother the curve" + ) + parser.add_argument("--linewidth", default=1.0, type=float, help="curve line width") + parser.add_argument("--y_max", default=None, type=float, help="y max value") + args = parser.parse_args() + + plot(args.input, args.output, args.smooth, args.alpha, linewidth=args.linewidth, y_max=args.y_max) From f2e01545f53d25f7efed7dae7416f19e3cfc9283 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 24 Dec 2024 11:59:46 +0800 Subject: [PATCH 129/133] remove resume_download --- .../models/causalvideovae/model/vae/modeling_causalvae.py | 4 ---- .../models/causalvideovae/model/vae/modeling_wfvae.py | 4 ---- .../opensora/models/diffusion/opensora/modeling_opensora.py | 4 ---- 3 files changed, 12 deletions(-) diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_causalvae.py b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_causalvae.py index bc38c0f975..4a2238b215 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_causalvae.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_causalvae.py @@ -169,7 +169,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) force_download = kwargs.pop("force_download", False) from_flax = kwargs.pop("from_flax", False) - resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) local_files_only = kwargs.pop("local_files_only", None) @@ -201,7 +200,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): return_unused_kwargs=True, return_commit_hash=True, force_download=force_download, - resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, @@ -224,7 +222,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, - resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, @@ -243,7 +240,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): weights_name=_add_variant(WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, - resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, diff --git a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py index c43b878dee..e0ad4b698d 100644 --- a/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py +++ b/examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py @@ -648,7 +648,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) force_download = kwargs.pop("force_download", False) from_flax = kwargs.pop("from_flax", False) - resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) local_files_only = kwargs.pop("local_files_only", None) @@ -681,7 +680,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): return_unused_kwargs=True, return_commit_hash=True, force_download=force_download, - resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, @@ -704,7 +702,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, - resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, @@ -723,7 +720,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): weights_name=_add_variant(WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, - resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, diff --git a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py index 7a0ba7bffd..7d429089d2 100644 --- a/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/examples/opensora_pku/opensora/models/diffusion/opensora/modeling_opensora.py @@ -154,7 +154,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) force_download = kwargs.pop("force_download", False) from_flax = kwargs.pop("from_flax", False) - resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) local_files_only = kwargs.pop("local_files_only", None) @@ -186,7 +185,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): return_unused_kwargs=True, return_commit_hash=True, force_download=force_download, - resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, @@ -209,7 +207,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, - resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, @@ -228,7 +225,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): weights_name=_add_variant(WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, - resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, From deb5fd123e6d696e401d1d654b47192e972b109f Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 24 Dec 2024 12:00:44 +0800 Subject: [PATCH 130/133] load captioner --- examples/opensora_pku/opensora/sample/sample.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/opensora_pku/opensora/sample/sample.py b/examples/opensora_pku/opensora/sample/sample.py index 9bf3121415..8b15be600c 100644 --- a/examples/opensora_pku/opensora/sample/sample.py +++ b/examples/opensora_pku/opensora/sample/sample.py @@ -12,7 +12,6 @@ import time from opensora.npu_config import npu_config -from opensora.sample.caption_refiner import OpenSoraCaptionRefiner from opensora.utils.sample_utils import get_args, prepare_pipeline, run_model_and_save_samples from mindone.utils.logger import set_logger @@ -40,6 +39,8 @@ pipeline = prepare_pipeline(args) # build I2V/T2V pipeline if args.caption_refiner is not None: + from opensora.sample.caption_refiner import OpenSoraCaptionRefiner + caption_refiner_model = OpenSoraCaptionRefiner(args.caption_refiner, dtype=ms.float16) else: caption_refiner_model = None From 3452269928b3c933ec5b7127ad9c2f54baa3b874 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Tue, 24 Dec 2024 12:41:05 +0800 Subject: [PATCH 131/133] merge multiple safetensors --- examples/opensora_pku/README.md | 5 ++ .../tools/ckpt/merge_safetensors.py | 68 +++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 examples/opensora_pku/tools/ckpt/merge_safetensors.py diff --git a/examples/opensora_pku/README.md b/examples/opensora_pku/README.md index ff4e06104e..84aa903064 100644 --- a/examples/opensora_pku/README.md +++ b/examples/opensora_pku/README.md @@ -151,6 +151,11 @@ python tools/model_conversion/convert_wfvae.py --src LanguageBind/Open-Sora-Plan python tools/model_conversion/convert_pytorch_ckpt_to_safetensors.py --src google/mt5-xxl/pytorch_model.bin --target google/mt5-xxl/model.safetensors --config google/mt5-xxl/config.json ``` +In addition, please merge the multiple .saftensors files under `any93x640x640/` into a merged checkpoint: +```shell +python tools/ckpt/merge_safetensors.py -i LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640/ -o LanguageBind/Open-Sora-Plan-v1.3.0/diffusion_pytorch_model.safetensors -f LanguageBind/Open-Sora-Plan-v1.3.0/any93x640x640/diffusion_pytorch_model.safetensors.index.json +``` + Once the checkpoint files have all been prepared, you can refer to the inference guidance below. ## Inference diff --git a/examples/opensora_pku/tools/ckpt/merge_safetensors.py b/examples/opensora_pku/tools/ckpt/merge_safetensors.py new file mode 100644 index 0000000000..b995248e14 --- /dev/null +++ b/examples/opensora_pku/tools/ckpt/merge_safetensors.py @@ -0,0 +1,68 @@ +import argparse +import json +import os + +from safetensors import safe_open +from safetensors.torch import save_file + + +def load_index_file(index_file): + with open(index_file, "r") as f: + return json.load(f) + + +def _load_huggingface_safetensor(ckpt_file): + db_state_dict = {} + with safe_open(ckpt_file, framework="pt", device="cpu") as f: + for key in f.keys(): + db_state_dict[key] = f.get_tensor(key) + return db_state_dict + + +def merge_safetensors(input_folder, index_file, output_file): + # Load the index file + index_data = load_index_file(index_file) + # Iterate through the files specified in the index + weight_map = index_data.get("weight_map", {}) + weight_names = [] + file_paths = [] + for weight_name in weight_map.keys(): + file_paths.append(weight_map[weight_name]) + weight_names.append(weight_name) + file_paths = set(file_paths) + weight_names = set(weight_names) + + sd = [] + for file_path in file_paths: + if file_path: + file_path = os.path.join(input_folder, file_path) + partial_sd = _load_huggingface_safetensor(file_path) + sd.append(partial_sd) + + # Merge all tensors together + merged_tensor = sd[0] + for tensor in sd[1:]: + merged_tensor.update(tensor) + + # Save the merged tensor to a new Safetensor file + save_file(merged_tensor, output_file) + print(f"Merged Safetensors saved as: {output_file}") + + +def main(): + # Set up argument parsing + parser = argparse.ArgumentParser(description="Merge multiple Safetensors files into one using an index.") + parser.add_argument("--input_folder", "-i", type=str, help="Path to the folder containing Safetensors files.") + parser.add_argument("--index_file", "-f", type=str, help="Path to the index JSON file.") + parser.add_argument("--output_file", "-o", type=str, help="Path to the output merged Safetensors file.") + + # Parse the arguments + args = parser.parse_args() + + # Call the merge function + assert args.output_file.endswith(".safetensors") + merge_safetensors(args.input_folder, args.index_file, args.output_file) + + +if __name__ == "__main__": + main() From b7c980ef396558ab58b1210cfd08d6e026abd3f7 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Thu, 2 Jan 2025 15:59:44 +0800 Subject: [PATCH 132/133] update readme --- examples/opensora_pku/README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/opensora_pku/README.md b/examples/opensora_pku/README.md index 84aa903064..cffa08794c 100644 --- a/examples/opensora_pku/README.md +++ b/examples/opensora_pku/README.md @@ -227,7 +227,7 @@ See more examples of multi-device inference scripts under `scripts/text_condifio ### Prompt Refiner Inference -If you want to run T2V inference with caption refiner, you should attach to following argument to the T2V inference command above: +If you want to run T2V inference with caption refiner, you should attach following argument to the T2V inference command above: ``` --caption_refiner "LanguageBind/Open-Sora-Plan-v1.3.0/prompt_refiner/" ``` @@ -502,7 +502,7 @@ See `train_t2v_stage2.sh` under `scripts/text_condition/mult-devices/` for detai #### Performance -We evaluated the training performance on Ascend NPUs. All experiments are running in PYNATIVE mode. The results are as follows. +We evaluated the training performance on Ascend NPUs. All experiments are running in PYNATIVE mode with MindSpore(2.3.1). The results are as follows. | model name | cards | stage | batch size (global) | video size | Paramllelism |recompute |data sink | jit level| step time | train imgs/s | |:----------------|:----------- |:---------:|:-----:|:----------:|:----------:|:----------:|:----------:|:----------:|-------------------:|:----------:| @@ -514,6 +514,8 @@ We evaluated the training performance on Ascend NPUs. All experiments are runnin > *: dynamic resolution using bucket sampler. The step time may vary across different batches due to the varied resolutions. +> train imgs/s: it is computed by $num\quad of\quad frames \times global\quad batch\quad size \div per\quad step\quad time$ + ## 👍 Acknowledgement * [Latte](https://github.com/Vchitect/Latte): The **main codebase** we built upon and it is an wonderful video generated model. * [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha): Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis. From df35b1d23f0f5f2d6ca95bc775375784ab958ab0 Mon Sep 17 00:00:00 2001 From: Didan Deng <33117903+wtomin@users.noreply.github.com> Date: Fri, 10 Jan 2025 18:02:31 +0800 Subject: [PATCH 133/133] export DEVICES --- .../scripts/text_condition/multi-devices/train_t2i_stage1.sh | 2 +- .../scripts/text_condition/multi-devices/train_t2v_stage2.sh | 2 +- .../scripts/text_condition/multi-devices/train_t2v_stage3.sh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh index 630f3c36d6..cc1c409f0f 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2i_stage1.sh @@ -2,7 +2,7 @@ NUM_FRAME=1 WIDTH=256 HEIGHT=256 -ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}/parallel_logs" \ opensora/train/train_t2v_diffusers.py \ --model OpenSoraT2V_v1_3-2B/122 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh index cc0150734c..0ebe09a216 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage2.sh @@ -3,7 +3,7 @@ NUM_FRAME=93 WIDTH=640 HEIGHT=640 MAX_HxW=409600 -ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}/parallel_logs" \ opensora/train/train_t2v_diffusers.py \ --model OpenSoraT2V_v1_3-2B/122 \ diff --git a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh index c16fa68c06..b0a369abeb 100644 --- a/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh +++ b/examples/opensora_pku/scripts/text_condition/multi-devices/train_t2v_stage3.sh @@ -2,7 +2,7 @@ NUM_FRAME=93 WIDTH=640 HEIGHT=352 -ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="t2v-video-${NUM_FRAME}x${HEIGHT}x${WIDTH}/parallel_logs" \ opensora/train/train_t2v_diffusers.py \ --model OpenSoraT2V_v1_3-2B/122 \