Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Open-Sora-Plan v1.3.0]: inference and training #717

Open
wants to merge 133 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 131 commits
Commits
Show all changes
133 commits
Select commit Hold shift + click to select a range
16cdd9b
op-v0.3 init
wtomin Oct 29, 2024
bf364a6
add dit, t2v inference
chenyingshu Oct 29, 2024
164bb08
align structure
chenyingshu Nov 4, 2024
be4e1d2
improve acc of PatchEmbed2D
chenyingshu Nov 4, 2024
62f175d
align structure, fix bugs
chenyingshu Nov 6, 2024
c5110a1
update
chenyingshu Nov 14, 2024
bde0718
training
chenyingshu Nov 14, 2024
fce6c11
Create video_data_v1_2.txt
chenyingshu Nov 14, 2024
8cdf20e
Create train_debug.sh
chenyingshu Nov 14, 2024
1a51b8d
Create train_debug.sh
chenyingshu Nov 14, 2024
bc42b7a
fix bug
chenyingshu Nov 14, 2024
562fefa
fix bug for recompute and sp
wtomin Nov 19, 2024
53336e1
update training scripts
wtomin Nov 19, 2024
870e2b2
fix syntax error for graph mode
wtomin Nov 21, 2024
36b561a
update import path & fix ci error
wtomin Nov 21, 2024
737854d
update sample script
wtomin Nov 21, 2024
fc68b91
update train scripts
wtomin Nov 21, 2024
fdf374d
fix error
wtomin Nov 21, 2024
43815c1
a valid file name
wtomin Nov 21, 2024
9c5c3cf
lazy_inline & update save video func
wtomin Nov 21, 2024
3124546
use cv2 to save videos
wtomin Nov 21, 2024
eec25f1
update requirements
wtomin Nov 21, 2024
b106f9f
save path correction
wtomin Nov 21, 2024
4260c22
revise generator with loss vae
wtomin Oct 30, 2024
6ce33e0
update vae training
wtomin Nov 5, 2024
2bf6cdb
support recompute
wtomin Nov 5, 2024
f343d90
correct loss scaler print
wtomin Nov 5, 2024
7ab7661
fix print loss scale
wtomin Nov 5, 2024
58607a9
update eval
wtomin Nov 6, 2024
c79c621
copy to vae/eval
wtomin Nov 6, 2024
727ccd0
a new vae config file
wtomin Nov 6, 2024
2ab6697
remove prefix
wtomin Nov 6, 2024
599e57e
allow model config
wtomin Nov 6, 2024
fc080ba
update 4dim json
wtomin Nov 6, 2024
b13b35a
save diffusers config
wtomin Nov 6, 2024
3835bfc
fix error
wtomin Nov 7, 2024
67c8f28
conv dtype bf16
wtomin Nov 7, 2024
5642d1a
conv 3d tranpose to fp16
wtomin Nov 7, 2024
6f543d1
init loss scale 65536
wtomin Nov 7, 2024
16efbb8
updates
wtomin Nov 7, 2024
5b8bf94
conv2d transpose use fp16
wtomin Nov 7, 2024
65bbda2
change conv2d initializer
wtomin Nov 11, 2024
c54809a
conv3d initializer
wtomin Nov 11, 2024
ce3d314
correct to fan_int
wtomin Nov 11, 2024
22ed936
print rec loss and p_loss & use fp32 for p_loss
wtomin Nov 12, 2024
277b716
updates for graph mode
wtomin Nov 12, 2024
9f5bf88
use ops.standard_normal
wtomin Nov 13, 2024
51d695c
default to print losses
wtomin Nov 13, 2024
60472fb
use max_grad_norm 0.1 for vae training
wtomin Nov 13, 2024
af81883
discriminator bf16 conv3d
wtomin Nov 13, 2024
ecdb0d1
set gen and disc weight decay
wtomin Nov 14, 2024
bac8c79
change training steps order: one loss per step
wtomin Nov 14, 2024
5f4a2ff
update printing log
wtomin Nov 14, 2024
23598df
set norm_dtype bf16
wtomin Nov 21, 2024
27bbd07
correct save video bug
wtomin Nov 22, 2024
e63bc10
print ops data types
wtomin Nov 22, 2024
1a5e9eb
resize nearest neighbor to npu_config.run
wtomin Nov 22, 2024
e40bcb6
print ops info
wtomin Nov 22, 2024
7606fc9
update dataset file
wtomin Nov 22, 2024
4c93b49
update train t2v script
wtomin Nov 22, 2024
5b06226
replace ops.cat with mint.cat
wtomin Nov 22, 2024
0623ebc
remove conflicting requirements
wtomin Nov 25, 2024
1ded897
connect_res_layer_num = 2
wtomin Nov 26, 2024
3b24f40
add new single-device training script
wtomin Nov 26, 2024
219b66e
no_grad in train pipelines
wtomin Nov 27, 2024
972c630
remove vae auto-mixed precision
wtomin Nov 27, 2024
227f2ec
allow memory_profile
wtomin Nov 27, 2024
abcab7a
update rec_video_folder
wtomin Nov 27, 2024
73da943
dynamic inputs
wtomin Nov 27, 2024
21ed520
update profile_memory
wtomin Nov 27, 2024
c14659c
no jit for dit
wtomin Nov 27, 2024
cefb859
sequence length can be not divisible by 16
wtomin Nov 28, 2024
d629441
ops.AlltoAll not support bf16
chenyingshu Dec 6, 2024
634f6c0
dyn resize
chenyingshu Dec 16, 2024
fa6f040
update scripts
chenyingshu Dec 16, 2024
91e075c
update script
chenyingshu Dec 16, 2024
72f9e9c
update causal cache
wtomin Nov 30, 2024
28fc60c
remove vae amp
wtomin Nov 30, 2024
c8cb23d
remove vae ms_checkpoint in sample.py
wtomin Nov 30, 2024
7b2772c
save config.json in train.py
wtomin Nov 30, 2024
1e3c33d
save_ema_only is False
wtomin Nov 30, 2024
a9786af
update npu_config.init_env
wtomin Nov 30, 2024
3542a62
ema is turned off by default
wtomin Nov 30, 2024
7a25659
update sample & train scripts
wtomin Nov 30, 2024
03f68a6
update train script name
wtomin Nov 30, 2024
a2f1ad9
tile_overlap_factor remove
wtomin Nov 30, 2024
fbdeadc
remove config.recompute
wtomin Dec 3, 2024
a2abf1d
load ms_checkpoint filter
wtomin Dec 3, 2024
c08d314
fix import error
wtomin Dec 3, 2024
47da9b4
fix import error
wtomin Dec 3, 2024
b337dc7
remove vae use_recompute
wtomin Dec 3, 2024
a6e9a46
allow sparse1d is False
wtomin Dec 4, 2024
5ad9b45
allow wavelet loss
wtomin Dec 4, 2024
2ac4c0d
print wavelet loss
wtomin Dec 4, 2024
c982edb
fix disc no_grad
wtomin Dec 4, 2024
d37ed03
fix vae tile decode error
wtomin Dec 5, 2024
61468fa
edit attention_mask shape
wtomin Dec 5, 2024
3f325d9
allow set sparse_n other than 4
wtomin Dec 6, 2024
3f17017
remove redudant args & allow rectified flow in training
wtomin Dec 11, 2024
93b4513
allow different length
wtomin Dec 12, 2024
ca4da34
allow parallel video reconstruction
wtomin Dec 12, 2024
5a183a8
print loss weight item and resume from log file
wtomin Dec 12, 2024
2aa786e
impr logging
wtomin Dec 12, 2024
e3dc5e6
readme update to vae training
wtomin Dec 16, 2024
fa850d0
update t2v training
wtomin Dec 16, 2024
5573d60
update sample shape to 352x640
wtomin Dec 16, 2024
b702e54
update readmd
wtomin Dec 16, 2024
33fbd4d
disable enable_tiling by default
wtomin Dec 16, 2024
8d01c3d
fix sp inference error
wtomin Dec 17, 2024
04b3087
1x256x256 exp script and performance update
wtomin Dec 17, 2024
618d4d2
update other stages performance table
wtomin Dec 17, 2024
54efbe3
rewrite test_data
wtomin Dec 17, 2024
a97395b
remove _backbone from dit ckpt
wtomin Dec 17, 2024
7af2562
update demo
wtomin Dec 18, 2024
0ffab7f
fix sp inference error
wtomin Dec 19, 2024
f09d610
make block input as SBH
wtomin Dec 20, 2024
9c72d71
update new dataset v0.3.0
wtomin Dec 20, 2024
48b9c82
set dataset sink mode
wtomin Dec 23, 2024
1c602b9
Revert "set dataset sink mode"
wtomin Dec 23, 2024
c1aefb5
update Collate usage
wtomin Dec 23, 2024
623de91
allow multiple test data shell
wtomin Dec 23, 2024
87679d1
use batch sampler not sampler
wtomin Dec 23, 2024
31bdcfd
use sampler instead of batch sampler
wtomin Dec 23, 2024
13407ff
caption refiner
wtomin Dec 23, 2024
733a446
fix PR run error
wtomin Dec 24, 2024
626e842
fix PR run error
wtomin Dec 24, 2024
1cdeb9f
mindnlp commit
wtomin Dec 24, 2024
5425470
plot loss curve
wtomin Dec 24, 2024
f2e0154
remove resume_download
wtomin Dec 24, 2024
deb5fd1
load captioner
wtomin Dec 24, 2024
3452269
merge multiple safetensors
wtomin Dec 24, 2024
b7c980e
update readme
wtomin Jan 2, 2025
df35b1d
export DEVICES
wtomin Jan 10, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
347 changes: 153 additions & 194 deletions examples/opensora_pku/README.md

Large diffs are not rendered by default.

88 changes: 42 additions & 46 deletions examples/opensora_pku/examples/rec_image.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,27 @@
"""
Run causal vae reconstruction on a given image
Usage example:
python examples/rec_image.py \
--ae_path LanguageBind/Open-Sora-Plan-v1.2.0/vae \
--image_path test.jpg \
--rec_path rec.jpg \
--device Ascend \
--short_size 512 \
--enable_tiling
"""
import argparse
import logging
import os
import re
import sys

import cv2
import numpy as np
from albumentations import Compose, Lambda, Resize, ToFloat
from PIL import Image

import mindspore as ms
from mindspore import nn

mindone_lib_path = os.path.abspath("../../")
sys.path.insert(0, mindone_lib_path)
from mindone.utils.amp import auto_mixed_precision

from mindone.utils.config import str2bool
from mindone.utils.logger import set_logger

sys.path.append(".")

import cv2
from albumentations import Compose, Lambda, Resize, ToFloat
from opensora.models import CausalVAEModelWrapper
from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate
from opensora.utils.ms_utils import init_env
from opensora.models.causalvideovae import ae_wrapper
from opensora.models.causalvideovae.model.registry import ModelRegistry
from opensora.npu_config import npu_config
from opensora.utils.utils import get_precision

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -60,7 +49,7 @@ def preprocess(image, height: int = 128, width: int = 128):

image = video_transform(image=image)["image"] # (h w c)
# (h w c) -> (c h w) -> (c t h w)
image = np.transpose(image, (2, 1, 0))[:, None, :, :]
image = np.transpose(image, (2, 0, 1))[:, None, :, :]
return image


Expand All @@ -75,27 +64,44 @@ def transform_to_rgb(x, rescale_to_uint8=True):
def main(args):
image_path = args.image_path
short_size = args.short_size
init_env(
mode=args.mode,
device_target=args.device,
precision_mode=args.precision_mode,
jit_level=args.jit_level,
jit_syntax_level=args.jit_syntax_level,
)
npu_config.set_npu_env(args)
npu_config.print_ops_dtype_info()

set_logger(name="", output_dir=args.output_path, rank=0)

dtype = get_precision(args.precision)
if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint):
logger.info(f"Run inference with MindSpore checkpoint {args.ms_checkpoint}")
state_dict = ms.load_checkpoint(args.ms_checkpoint)
# rm 'network.' prefix

state_dict = dict(
[k.replace("autoencoder.", "") if k.startswith("autoencoder.") else k, v] for k, v in state_dict.items()
)
state_dict = dict([k.replace("_backbone.", "") if "_backbone." in k else k, v] for k, v in state_dict.items())
else:
state_dict = None
kwarg = {"state_dict": state_dict, "use_safetensors": True}
vae = CausalVAEModelWrapper(args.ae_path, **kwarg)
vae = None
if args.model_config is not None:
assert os.path.exists(args.model_config), f"`model_config` does not exist! {args.model_config}"
pattern = r"^([A-Za-z]+)Model"
if re.match(pattern, args.ae):
model_name = re.match(pattern, args.ae).group(1)
model_cls = ModelRegistry.get_model(model_name)
vae = model_cls.from_config(args.model_config, dtype=dtype)
if args.ms_checkpoint is None or not os.path.exists(args.ms_checkpoint):
logger.warning(
"VAE is randomly initialized. The inference results may be incorrect! Check `ms_checkpoint`!"
)

else:
logger.warning(f"Incorrect ae name, must be one of {ae_wrapper.keys()}")

kwarg = {
"state_dict": state_dict,
"use_safetensors": True,
"dtype": dtype,
"vae": vae,
}
vae = ae_wrapper[args.ae](args.ae_path, **kwarg)

if args.enable_tiling:
vae.vae.enable_tiling()
Expand All @@ -104,25 +110,11 @@ def main(args):
vae.set_train(False)
for param in vae.get_parameters():
param.requires_grad = False
if args.precision in ["fp16", "bf16"]:
amp_level = "O2"
dtype = get_precision(args.precision)
if dtype == ms.float16:
custom_fp32_cells = [nn.GroupNorm] if args.vae_keep_gn_fp32 else []
else:
custom_fp32_cells = [nn.AvgPool2d, TrilinearInterpolate]

vae = auto_mixed_precision(vae, amp_level, dtype, custom_fp32_cells=custom_fp32_cells)
logger.info(
f"Set mixed precision to {amp_level} with dtype={args.precision}, custom fp32_cells {custom_fp32_cells}"
)
elif args.precision == "fp32":
dtype = get_precision(args.precision)
else:
raise ValueError(f"Unsupported precision {args.precision}")
input_x = np.array(Image.open(image_path)) # (h w c)
assert input_x.shape[2], f"Expect the input image has three channels, but got shape {input_x.shape}"
x_vae = preprocess(input_x, short_size, short_size) # use image as a single-frame video

x_vae = ms.Tensor(x_vae, dtype).unsqueeze(0) # b c t h w
latents = vae.encode(x_vae)
latents = latents.to(dtype)
Expand All @@ -147,14 +139,15 @@ def main(args):
parser = argparse.ArgumentParser()
parser.add_argument("--image_path", type=str, default="")
parser.add_argument("--rec_path", type=str, default="")
parser.add_argument("--ae", type=str, default="WFVAEModel_D8_4x8x8", choices=ae_wrapper.keys())
parser.add_argument("--ae_path", type=str, default="results/pretrained")
parser.add_argument("--ms_checkpoint", type=str, default=None)
parser.add_argument("--short_size", type=int, default=336)
parser.add_argument("--tile_overlap_factor", type=float, default=0.25)
parser.add_argument("--tile_sample_min_size", type=int, default=256)
parser.add_argument("--enable_tiling", action="store_true")
# ms related
parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode")
parser.add_argument("--mode", default=1, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode")
parser.add_argument(
"--precision",
default="bf16",
Expand Down Expand Up @@ -188,5 +181,8 @@ def main(args):
parser.add_argument(
"--jit_syntax_level", default="strict", choices=["strict", "lax"], help="Set jit syntax level: strict or lax"
)
parser.add_argument(
"--model_config", type=str, default=None, help="The model config file for initiating vae model."
)
args = parser.parse_args()
main(args)
102 changes: 42 additions & 60 deletions examples/opensora_pku/examples/rec_video.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,31 @@
"""
Run causal vae reconstruction on a given video.
Usage example:
python examples/rec_video.py \
--ae_path path/to/vae/ckpt \
--video_path test.mp4 \
--rec_path rec.mp4 \
--sample_rate 1 \
--num_frames 65 \
--height 480 \
--width 640 \
"""
import argparse
import logging
import os
import random
import re
import sys

import numpy as np
from decord import VideoReader, cpu
from PIL import Image

import mindspore as ms
from mindspore import nn

mindone_lib_path = os.path.abspath("../../")
sys.path.insert(0, mindone_lib_path)
from mindone.utils.amp import auto_mixed_precision
from mindone.utils.config import str2bool
from mindone.utils.logger import set_logger
from mindone.visualize.videos import save_videos

sys.path.append(".")
from functools import partial

import cv2
from albumentations import Compose, Lambda, Resize, ToFloat
from opensora.dataset.transform import center_crop_th_tw
from opensora.models import CausalVAEModelWrapper
from opensora.models.causalvideovae.model.modules.updownsample import TrilinearInterpolate
from opensora.utils.ms_utils import init_env
from opensora.models.causalvideovae import ae_wrapper
from opensora.models.causalvideovae.model.registry import ModelRegistry
from opensora.npu_config import npu_config
from opensora.utils.utils import get_precision
from opensora.utils.video_utils import save_videos

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -117,56 +103,55 @@ def transform_to_rgb(x, rescale_to_uint8=True):


def main(args):
init_env(
mode=args.mode,
device_target=args.device,
precision_mode=args.precision_mode,
jit_level=args.jit_level,
jit_syntax_level=args.jit_syntax_level,
)

npu_config.set_npu_env(args)
npu_config.print_ops_dtype_info()
dtype = get_precision(args.precision)
set_logger(name="", output_dir=args.output_path, rank=0)
if args.ms_checkpoint is not None and os.path.exists(args.ms_checkpoint):
logger.info(f"Run inference with MindSpore checkpoint {args.ms_checkpoint}")
state_dict = ms.load_checkpoint(args.ms_checkpoint)
# rm 'network.' prefix

state_dict = dict(
[k.replace("autoencoder.", "") if k.startswith("autoencoder.") else k, v] for k, v in state_dict.items()
)
state_dict = dict([k.replace("_backbone.", "") if "_backbone." in k else k, v] for k, v in state_dict.items())
else:
state_dict = None
kwarg = {"state_dict": state_dict, "use_safetensors": True}
vae = CausalVAEModelWrapper(args.ae_path, **kwarg)

vae = None
if args.model_config is not None:
assert os.path.exists(args.model_config), f"`model_config` does not exist! {args.model_config}"
pattern = r"^([A-Za-z]+)Model"
if re.match(pattern, args.ae):
model_name = re.match(pattern, args.ae).group(1)
model_cls = ModelRegistry.get_model(model_name)
vae = model_cls.from_config(args.model_config, dtype=dtype)
if args.ms_checkpoint is None or not os.path.exists(args.ms_checkpoint):
logger.warning(
"VAE is randomly initialized. The inference results may be incorrect! Check `ms_checkpoint`!"
)

else:
logger.warning(f"Incorrect ae name, must be one of {ae_wrapper.keys()}")

kwarg = {
"state_dict": state_dict,
"use_safetensors": True,
"dtype": dtype,
"vae": vae,
}
vae = ae_wrapper[args.ae](args.ae_path, **kwarg)

if args.enable_tiling:
vae.vae.enable_tiling()
vae.vae.tile_overlap_factor = args.tile_overlap_factor
if args.save_memory:
vae.vae.tile_sample_min_size = args.tile_sample_min_size
vae.vae.tile_latent_min_size = 32
vae.vae.tile_sample_min_size_t = 29
vae.vae.tile_latent_min_size_t = 8

vae.set_train(False)
for param in vae.get_parameters():
param.requires_grad = False
if args.precision in ["fp16", "bf16"]:
amp_level = "O2"
dtype = get_precision(args.precision)
if dtype == ms.float16:
custom_fp32_cells = [nn.GroupNorm] if args.vae_keep_gn_fp32 else []
else:
custom_fp32_cells = [nn.AvgPool2d, TrilinearInterpolate]
vae = auto_mixed_precision(vae, amp_level, dtype, custom_fp32_cells=custom_fp32_cells)
logger.info(
f"Set mixed precision to {amp_level} with dtype={args.precision}, custom fp32_cells {custom_fp32_cells}"
)
elif args.precision == "fp32":
dtype = get_precision(args.precision)
else:
raise ValueError(f"Unsupported precision {args.precision}")

x_vae = preprocess(read_video(args.video_path, args.num_frames, args.sample_rate), args.height, args.width)

x_vae = ms.Tensor(x_vae, dtype).unsqueeze(0) # b c t h w
latents = vae.encode(x_vae)
latents = latents.to(dtype)
Expand Down Expand Up @@ -202,19 +187,18 @@ def main(args):
parser = argparse.ArgumentParser()
parser.add_argument("--video_path", type=str, default="")
parser.add_argument("--rec_path", type=str, default="")
parser.add_argument("--ae", type=str, default="")
parser.add_argument("--ae_path", type=str, default="results/pretrained")
parser.add_argument("--ms_checkpoint", type=str, default=None)
parser.add_argument("--fps", type=int, default=30)
parser.add_argument("--height", type=int, default=336)
parser.add_argument("--width", type=int, default=336)
parser.add_argument("--num_frames", type=int, default=65)
parser.add_argument("--sample_rate", type=int, default=1)
parser.add_argument("--tile_overlap_factor", type=float, default=0.25)
parser.add_argument("--tile_sample_min_size", type=int, default=256)
parser.add_argument("--enable_tiling", action="store_true")
parser.add_argument("--save_memory", action="store_true")
parser.add_argument("--tile_overlap_factor", type=float, default=0.25)
# ms related
parser.add_argument("--mode", default=0, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode")
parser.add_argument("--mode", default=1, type=int, help="Specify the mode: 0 for graph mode, 1 for pynative mode")
parser.add_argument(
"--precision",
default="bf16",
Expand All @@ -223,12 +207,7 @@ def main(args):
help="mixed precision type, if fp32, all layer precision is float32 (amp_level=O0), \
if bf16 or fp16, amp_level==O2, part of layers will compute in bf16 or fp16 such as matmul, dense, conv.",
)
parser.add_argument(
"--vae_keep_gn_fp32",
default=False,
type=str2bool,
help="whether keep GroupNorm in fp32. Defaults to False in inference, better to set to True when training vae",
)

parser.add_argument("--device", type=str, default="Ascend", help="Ascend or GPU")
parser.add_argument(
"--precision_mode",
Expand All @@ -248,5 +227,8 @@ def main(args):
parser.add_argument(
"--jit_syntax_level", default="strict", choices=["strict", "lax"], help="Set jit syntax level: strict or lax"
)
parser.add_argument(
"--model_config", type=str, default=None, help="The model config file for initiating vae model."
)
args = parser.parse_args()
main(args)
Loading
Loading