From de77902bfca4077433fc19f37732c08ff3cb0c3f Mon Sep 17 00:00:00 2001
From: Didan Deng <33117903+wtomin@users.noreply.github.com>
Date: Fri, 20 Dec 2024 11:30:50 +0800
Subject: [PATCH] update new dataset v0.3.0

---
 .../opensora_pku/opensora/dataset/__init__.py |  46 +-
 .../opensora_pku/opensora/dataset/loader.py   |  36 +-
 .../opensora/dataset/t2v_datasets.py          | 796 ++++++++++++------
 .../opensora/dataset/transform.py             | 230 ++++-
 .../opensora/train/train_t2v_diffusers.py     |  63 +-
 .../opensora/utils/dataset_utils.py           | 380 ++++++---
 .../scripts/train_data/video_data_v1_2.txt    |   2 +-
 examples/opensora_pku/tests/test_data.py      |  48 +-
 examples/opensora_pku/tests/test_data.sh      |  11 +-
 9 files changed, 1103 insertions(+), 509 deletions(-)

diff --git a/examples/opensora_pku/opensora/dataset/__init__.py b/examples/opensora_pku/opensora/dataset/__init__.py
index d596a1d9a9..8c38d20fa6 100644
--- a/examples/opensora_pku/opensora/dataset/__init__.py
+++ b/examples/opensora_pku/opensora/dataset/__init__.py
@@ -6,7 +6,7 @@ from transformers import AutoTokenizer
 
 from .t2v_datasets import T2V_dataset
-from .transform import TemporalRandomCrop, center_crop_th_tw, spatial_stride_crop_video, maxhxw_resize
+from .transform import TemporalRandomCrop, center_crop_th_tw, maxhxw_resize, spatial_stride_crop_video
 
 
 def getdataset(args, dataset_file):
@@ -18,16 +18,11 @@ def norm_func_albumentation(image, **kwargs):
     mapping = {"bilinear": cv2.INTER_LINEAR, "bicubic": cv2.INTER_CUBIC}
     targets = {"image{}".format(i): "image" for i in range(args.num_frames)}
-    resize_topcrop = [
-        Lambda(
-            name="crop_topcrop",
-            image=partial(center_crop_th_tw, th=args.max_height, tw=args.max_width, top_crop=True),
-            p=1.0,
-        ),
-        Resize(args.max_height, args.max_width, interpolation=mapping["bilinear"]),
-    ]
+
     if args.force_resolution:
-        assert (args.max_height is not None) and (args.max_width is not None), "set max_height and max_width for fixed resolution"
+        assert (args.max_height is not None) and (
+            args.max_width is not None
+        ), "set max_height and max_width for fixed resolution"
         resize = [
             Lambda(
                 name="crop_centercrop",
@@ -36,7 +31,7 @@ def norm_func_albumentation(image, **kwargs):
             ),
             Resize(args.max_height, args.max_width, interpolation=mapping["bilinear"]),
         ]
-    else: # dynamic resolution
+    else:  # dynamic resolution
         assert args.max_hxw is not None, "set max_hxw for dynamic resolution"
         resize = [
             Lambda(
@@ -46,7 +41,7 @@ def norm_func_albumentation(image, **kwargs):
             ),
             Lambda(
                 name="spatial_stride_crop",
-                image=partial(spatial_stride_crop_video, stride=args.hw_stride), # default stride=32
+                image=partial(spatial_stride_crop_video, stride=args.hw_stride),  # default stride=32
                 p=1.0,
             ),
         ]
@@ -55,35 +50,20 @@ def norm_func_albumentation(image, **kwargs):
         [*resize, ToFloat(255.0), Lambda(name="ae_norm", image=norm_func_albumentation, p=1.0)],
         additional_targets=targets,
     )
-    transform_topcrop = Compose(
-        [*resize_topcrop, ToFloat(255.0), Lambda(name="ae_norm", image=norm_func_albumentation, p=1.0)],
-        additional_targets=targets,
-    )
-    tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name_1, cache_dir=args.cache_dir)
+    tokenizer_1 = AutoTokenizer.from_pretrained(args.text_encoder_name_1, cache_dir=args.cache_dir)
+    tokenizer_2 = None
     if args.text_encoder_name_2 is not None:
         tokenizer_2 = AutoTokenizer.from_pretrained(args.text_encoder_name_2, cache_dir=args.cache_dir)
 
     if args.dataset == "t2v":
         return T2V_dataset(
-            dataset_file,
-            num_frames=args.num_frames,
-            train_fps=args.train_fps,
-            use_image_num=args.use_image_num,
-
use_img_from_vid=args.use_img_from_vid, - model_max_length=args.model_max_length, - cfg=args.cfg, - speed_factor=args.speed_factor, - max_height=args.max_height, - max_width=args.max_width, - drop_short_ratio=args.drop_short_ratio, - dataloader_num_workers=args.dataloader_num_workers, - text_encoder_name=args.text_encoder_name_1, # TODO: update with 2nd text encoder - return_text_emb=args.text_embed_cache, + args, transform=transform, temporal_sample=temporal_sample, - tokenizer=tokenizer, - transform_topcrop=transform_topcrop, + tokenizer_1=tokenizer_1, + tokenizer_2=tokenizer_2, + return_text_emb=args.text_embed_cache, ) elif args.dataset == "inpaint" or args.dataset == "i2v": raise NotImplementedError diff --git a/examples/opensora_pku/opensora/dataset/loader.py b/examples/opensora_pku/opensora/dataset/loader.py index 832dbd55fa..d53372a0dc 100644 --- a/examples/opensora_pku/opensora/dataset/loader.py +++ b/examples/opensora_pku/opensora/dataset/loader.py @@ -26,6 +26,7 @@ def create_dataloader( enable_modelarts=False, collate_fn=None, sampler=None, + batch_sampler=None, ): datalen = len(dataset) @@ -46,6 +47,7 @@ def create_dataloader( shuffle=shuffle, drop_last=drop_last, sampler=sampler, + batch_sampler=batch_sampler, ) dl = GeneratorDataset( loader, @@ -62,13 +64,24 @@ def create_dataloader( def build_dataloader( - dataset, datalens, collate_fn, batch_size, device_num, rank_id=0, sampler=None, shuffle=True, drop_last=True + dataset, + datalens, + collate_fn, + batch_size, + device_num, + rank_id=0, + sampler=None, + batch_sampler=None, + shuffle=True, + drop_last=True, ): - if sampler is None: - sampler = BatchSampler(datalens, batch_size=batch_size, device_num=device_num, shuffle=shuffle) + if batch_sampler is None: + batch_sampler = BatchSampler(datalens, batch_size=batch_size, device_num=device_num, shuffle=shuffle) loader = DataLoader( dataset, - batch_sampler=sampler, + batch_size=batch_size, + sampler=sampler, + batch_sampler=batch_sampler, collate_fn=collate_fn, device_num=device_num, drop_last=drop_last, @@ -107,14 +120,25 @@ def __len__(self): class DataLoader: """DataLoader""" - def __init__(self, dataset, batch_sampler, collate_fn, device_num=1, drop_last=True, rank_id=0): + def __init__( + self, + dataset, + batch_size, + sampler=None, + batch_sampler=None, + collate_fn=None, + device_num=1, + drop_last=True, + rank_id=0, + ): self.dataset = dataset + self.sampler = sampler self.batch_sampler = batch_sampler self.collat_fn = collate_fn self.device_num = device_num self.rank_id = rank_id self.drop_last = drop_last - self.batch_size = len(next(iter(self.batch_sampler))) + self.batch_size = batch_size def __iter__(self): self.step_index = 0 diff --git a/examples/opensora_pku/opensora/dataset/t2v_datasets.py b/examples/opensora_pku/opensora/dataset/t2v_datasets.py index 2909d361b0..013d4be5b1 100644 --- a/examples/opensora_pku/opensora/dataset/t2v_datasets.py +++ b/examples/opensora_pku/opensora/dataset/t2v_datasets.py @@ -3,15 +3,27 @@ import logging import math import os +import pickle import random +import time from collections import Counter +from concurrent.futures import ThreadPoolExecutor from os.path import join as opj from pathlib import Path +import cv2 +import decord import numpy as np -from opensora.utils.dataset_utils import DecordInit +from opensora.dataset.transform import ( + add_aesthetic_notice_image, + add_aesthetic_notice_video, + calculate_statistics, + get_params, + maxhwresize, +) from opensora.utils.utils import text_preprocessing from PIL import 
Image +from tqdm import tqdm logger = logging.getLogger(__name__) @@ -84,10 +96,40 @@ def get_item(self, work_info): dataset_prog = DataSetProg() -def find_closest_y(x, vae_stride_t=4, model_ds_t=4): - if x < 29: +class DecordDecoder(object): + def __init__(self, url, num_threads=1): + self.num_threads = num_threads + self.ctx = decord.cpu(0) + self.reader = decord.VideoReader(url, ctx=self.ctx, num_threads=self.num_threads) + + def get_avg_fps(self): + return self.reader.get_avg_fps() if self.reader.get_avg_fps() > 0 else 30.0 + + def get_num_frames(self): + return len(self.reader) + + def get_height(self): + return self.reader[0].shape[0] if self.get_num_frames() > 0 else 0 + + def get_width(self): + return self.reader[0].shape[1] if self.get_num_frames() > 0 else 0 + + # output shape [T, H, W, C] + def get_batch(self, frame_indices): + try: + # frame_indices[0] = 1000 + video_data = self.reader.get_batch(frame_indices).asnumpy() + return video_data + except Exception as e: + print("get_batch execption:", e) + return None + + +def find_closest_y(x, vae_stride_t=4, model_ds_t=1): + min_num_frames = 29 + if x < min_num_frames: return -1 - for y in range(x, 12, -1): + for y in range(x, min_num_frames - 1, -1): if (y - 1) % vae_stride_t == 0 and ((y - 1) // vae_stride_t + 1) % model_ds_t == 0: return y return -1 @@ -102,79 +144,334 @@ def filter_resolution(h, w, max_h_div_w_ratio=17 / 16, min_h_div_w_ratio=8 / 16) class T2V_dataset: def __init__( self, - data, - num_frames: int = 29, - train_fps: int = 24, - use_image_num: int = 0, - use_img_from_vid: bool = False, - model_max_length: int = 512, - cfg: float = 0.1, - speed_factor: float = 1.0, - max_height: int = 480, - max_width: int = 640, - drop_short_ratio: float = 1.0, - dataloader_num_workers: int = 10, - text_encoder_name: str = "google/mt5-xxl", - transform=None, - temporal_sample=None, - tokenizer=None, - transform_topcrop=None, + args, + transform, + temporal_sample, + tokenizer_1, + tokenizer_2, filter_nonexistent=True, return_text_emb=False, ): - self.data = data - self.num_frames = num_frames - self.train_fps = train_fps - self.use_image_num = use_image_num - self.use_img_from_vid = use_img_from_vid + self.data = args.data + self.num_frames = args.num_frames + self.train_fps = args.train_fps self.transform = transform - self.transform_topcrop = transform_topcrop self.temporal_sample = temporal_sample - self.tokenizer = tokenizer - self.model_max_length = model_max_length - self.cfg = cfg - self.speed_factor = speed_factor - self.max_height = max_height - self.max_width = max_width - self.drop_short_ratio = drop_short_ratio + self.tokenizer_1 = tokenizer_1 + self.tokenizer_2 = tokenizer_2 + self.model_max_length = args.model_max_length + self.cfg = args.cfg + self.speed_factor = args.speed_factor + self.max_height = args.max_height + self.max_width = args.max_width + self.drop_short_ratio = args.drop_short_ratio + self.hw_stride = args.hw_stride + self.force_resolution = args.force_resolution + self.max_hxw = args.max_hxw + self.min_hxw = args.min_hxw + self.sp_size = args.sp_size assert self.speed_factor >= 1 - self.v_decoder = DecordInit() + self.video_reader = "decord" if args.use_decord else "opencv" + self.ae_stride_t = args.ae_stride_t + self.total_batch_size = args.total_batch_size + self.seed = args.seed + self.generator = np.random.default_rng(self.seed) + self.hw_aspect_thr = 2.0 # just a threshold + self.too_long_factor = 10.0 # set this threshold larger for longer video datasets self.filter_nonexistent = 
filter_nonexistent self.return_text_emb = return_text_emb if self.return_text_emb and self.cfg > 0: logger.warning(f"random text drop ratio {self.cfg} will be ignored when text embeddings are cached.") self.duration_threshold = 100.0 - self.support_Chinese = True - if not ("mt5" in text_encoder_name): - self.support_Chinese = False + self.support_Chinese = False + if "mt5" in args.text_encoder_name_1: + self.support_Chinese = True + if args.text_encoder_name_2 is not None and "mt5" in args.text_encoder_name_2: + self.support_Chinese = True + s = time.time() + cap_list, self.sample_size, self.shape_idx_dict = self.define_frame_index(self.data) + e = time.time() + print(f"Build data time: {e-s}") + self.lengths = self.sample_size - cap_list = self.get_cap_list() - if self.filter_nonexistent: - cap_list = self.filter_nonexistent_files(cap_list) + n_elements = len(cap_list) + dataset_prog.set_cap_list(args.dataloader_num_workers, cap_list, n_elements) + print(f"Data length: {len(dataset_prog.cap_list)}") + self.executor = ThreadPoolExecutor(max_workers=1) + self.timeout = 60 - assert len(cap_list) > 0 - cap_list, self.sample_num_frames = self.define_frame_index(cap_list) - self.lengths = self.sample_num_frames + def define_frame_index(self, data): + shape_idx_dict = {} + new_cap_list = [] + sample_size = [] + aesthetic_score = [] + cnt_vid = 0 + cnt_img = 0 + cnt_too_long = 0 + cnt_too_short = 0 + cnt_no_cap = 0 + cnt_no_resolution = 0 + cnt_no_aesthetic = 0 + cnt_img_res_mismatch_stride = 0 + cnt_vid_res_mismatch_stride = 0 + cnt_img_aspect_mismatch = 0 + cnt_vid_aspect_mismatch = 0 + cnt_img_res_too_small = 0 + cnt_vid_res_too_small = 0 + cnt_vid_after_filter = 0 + cnt_img_after_filter = 0 + cnt_no_existent = 0 + cnt = 0 - n_elements = len(cap_list) - dataset_prog.set_cap_list(dataloader_num_workers, cap_list, n_elements) - - print(f"video length: {len(dataset_prog.cap_list)}", flush=True) - - def filter_nonexistent_files(self, cap_list): - indexes_to_remove = [] - for i, item in enumerate(cap_list): - path = item["path"] - if not os.path.exists(path): - second_path = path.replace("_resize1080p.mp4", ".mp4") - if os.path.exists(second_path): - cap_list[i]["path"] = second_path + with open(data, "r") as f: + folder_anno = [i.strip().split(",") for i in f.readlines() if len(i.strip()) > 0] + assert len(folder_anno) > 0, "input dataset file cannot be empty!" + for input_dataset in tqdm(folder_anno): + text_embed_folder_1, text_embed_folder_2 = None, None + if len(input_dataset) == 2: + assert not self.return_text_emb, "Train without text embedding cache!" 
+ elif len(input_dataset) == 3: + text_embed_folder_1 = input_dataset[1] + sub_root, anno = input_dataset[0], input_dataset[-1] + elif len(input_dataset) == 4: + text_embed_folder_1 = input_dataset[1] + text_embed_folder_2 = input_dataset[2] + sub_root, anno = input_dataset[0], input_dataset[-1] + else: + raise ValueError("Not supported input dataset file!") + + print(f"Building {anno}...") + if anno.endswith(".json"): + with open(anno, "r") as f: + sub_list = json.load(f) + elif anno.endswith(".pkl"): + with open(anno, "rb") as f: + sub_list = pickle.load(f) + for index, i in enumerate(tqdm(sub_list)): + cnt += 1 + path = os.path.join(sub_root, i["path"]) + if self.filter_nonexistent: + if not os.path.exists(path): + cnt_no_existent += 1 + continue + + if self.return_text_emb: + text_embeds_paths = self.get_text_embed_file_path(i) + if text_embed_folder_1 is not None: + i["text_embed_path_1"] = [opj(text_embed_folder_1, tp) for tp in text_embeds_paths] + if any([not os.path.exists(p) for p in i["text_embed_path_1"]]): + cnt_no_existent += 1 + continue + if text_embed_folder_2 is not None: + i["text_embed_path_2"] = [opj(text_embed_folder_2, tp) for tp in text_embeds_paths] + if any([not os.path.exists(p) for p in i["text_embed_path_2"]]): + cnt_no_existent += 1 + continue + + if path.endswith(".mp4"): + cnt_vid += 1 + elif path.endswith(".jpg"): + cnt_img += 1 + + # ======no aesthetic===== + if i.get("aesthetic", None) is None or i.get("aes", None) is None: + cnt_no_aesthetic += 1 + else: + aesthetic_score.append(i.get("aesthetic", None) or i.get("aes", None)) + + # ======no caption===== + cap = i.get("cap", None) + if cap is None: + cnt_no_cap += 1 + continue + + # ======resolution mismatch===== + i["path"] = path + assert ( + "resolution" in i + ), "Expect that each element in the provided datset should have a item named `resolution`" + if i.get("resolution", None) is None: + cnt_no_resolution += 1 + continue else: - indexes_to_remove.append(i) - cap_list = [item for i, item in enumerate(cap_list) if i not in indexes_to_remove] - logger.info(f"Nonexistent files: {len(indexes_to_remove)}") - return cap_list + assert ( + "height" in i["resolution"] and "width" in i["resolution"] + ), "Expect that each element has `resolution: \\{'height': int, 'width': int,\\}`" + if i["resolution"].get("height", None) is None or i["resolution"].get("width", None) is None: + cnt_no_resolution += 1 + continue + else: + height, width = i["resolution"]["height"], i["resolution"]["width"] + if not self.force_resolution: + if height <= 0 or width <= 0: + cnt_no_resolution += 1 + continue + + tr_h, tr_w = maxhwresize(height, width, self.max_hxw) + _, _, sample_h, sample_w = get_params(tr_h, tr_w, self.hw_stride) + + if sample_h <= 0 or sample_w <= 0: + if path.endswith(".mp4"): + cnt_vid_res_mismatch_stride += 1 + elif path.endswith(".jpg"): + cnt_img_res_mismatch_stride += 1 + continue + + # filter min_hxw + if sample_h * sample_w < self.min_hxw: + if path.endswith(".mp4"): + cnt_vid_res_too_small += 1 + elif path.endswith(".jpg"): + cnt_img_res_too_small += 1 + continue + + # filter aspect + is_pick = filter_resolution( + sample_h, + sample_w, + max_h_div_w_ratio=self.hw_aspect_thr, + min_h_div_w_ratio=1 / self.hw_aspect_thr, + ) + if not is_pick: + if path.endswith(".mp4"): + cnt_vid_aspect_mismatch += 1 + elif path.endswith(".jpg"): + cnt_img_aspect_mismatch += 1 + continue + + i["resolution"].update(dict(sample_height=sample_h, sample_width=sample_w)) + + else: + aspect = self.max_height / self.max_width + 
is_pick = filter_resolution( + height, + width, + max_h_div_w_ratio=self.hw_aspect_thr * aspect, + min_h_div_w_ratio=1 / self.hw_aspect_thr * aspect, + ) + if not is_pick: + if path.endswith(".mp4"): + cnt_vid_aspect_mismatch += 1 + elif path.endswith(".jpg"): + cnt_img_aspect_mismatch += 1 + continue + sample_h, sample_w = self.max_height, self.max_width + + i["resolution"].update(dict(sample_height=sample_h, sample_width=sample_w)) + + if path.endswith(".mp4"): + fps = i.get("fps", 24) + # max 5.0 and min 1.0 are just thresholds to filter some videos which have suitable duration. + assert ( + "num_frames" in i + ), "Expect that each element in the provided datset should have a item named `num_frames`" + if i["num_frames"] > self.too_long_factor * ( + self.num_frames * fps / self.train_fps * self.speed_factor + ): # too long video is not suitable for this training stage (self.num_frames) + cnt_too_long += 1 + continue + + # resample in case high fps, such as 50/60/90/144 -> train_fps(e.g, 24) + frame_interval = 1.0 if abs(fps - self.train_fps) < 0.1 else fps / self.train_fps + start_frame_idx = i.get("cut", [0])[0] + i["start_frame_idx"] = start_frame_idx + frame_indices = np.arange( + start_frame_idx, start_frame_idx + i["num_frames"], frame_interval + ).astype(int) + frame_indices = frame_indices[frame_indices < start_frame_idx + i["num_frames"]] + + # comment out it to enable dynamic frames training + if len(frame_indices) < self.num_frames and self.generator.random() < self.drop_short_ratio: + cnt_too_short += 1 + continue + + # too long video will be temporal-crop randomly + if len(frame_indices) > self.num_frames: + begin_index, end_index = self.temporal_sample(len(frame_indices)) + frame_indices = frame_indices[begin_index:end_index] + # frame_indices = frame_indices[:self.num_frames] # head crop + # to find a suitable end_frame_idx, to ensure we do not need pad video + end_frame_idx = find_closest_y( + len(frame_indices), vae_stride_t=self.ae_stride_t, model_ds_t=self.sp_size + ) + if end_frame_idx == -1: # too short that can not be encoded exactly by videovae + cnt_too_short += 1 + continue + frame_indices = frame_indices[:end_frame_idx] + + i["sample_frame_index"] = frame_indices.tolist() + + new_cap_list.append(i) + cnt_vid_after_filter += 1 + + elif path.endswith(".jpg"): # image + cnt_img_after_filter += 1 + i["sample_frame_index"] = [0] + new_cap_list.append(i) + + else: + raise NameError( + f"Unknown file extention {path.split('.')[-1]}, only support .mp4 for video and .jpg for image" + ) + + pre_define_shape = f"{len(i['sample_frame_index'])}x{sample_h}x{sample_w}" + sample_size.append(pre_define_shape) + # if shape_idx_dict.get(pre_define_shape, None) is None: + # shape_idx_dict[pre_define_shape] = [index] + # else: + # shape_idx_dict[pre_define_shape].append(index) + counter = Counter(sample_size) + counter_cp = counter + if not self.force_resolution and self.max_hxw is not None and self.min_hxw is not None: + assert all( + [np.prod(np.array(k.split("x")[1:]).astype(np.int32)) <= self.max_hxw for k in counter_cp.keys()] + ) + assert all( + [np.prod(np.array(k.split("x")[1:]).astype(np.int32)) >= self.min_hxw for k in counter_cp.keys()] + ) + + len_before_filter_major = len(sample_size) + filter_major_num = 4 * self.total_batch_size + new_cap_list, sample_size = zip( + *[[i, j] for i, j in zip(new_cap_list, sample_size) if counter[j] >= filter_major_num] + ) + for idx, shape in enumerate(sample_size): + if shape_idx_dict.get(shape, None) is None: + shape_idx_dict[shape] = 
[idx] + else: + shape_idx_dict[shape].append(idx) + cnt_filter_minority = len_before_filter_major - len(sample_size) + counter = Counter(sample_size) + + print( + f"no_cap: {cnt_no_cap}, no_resolution: {cnt_no_resolution}\n" + f"too_long: {cnt_too_long}, too_short: {cnt_too_short}\n" + f"cnt_img_res_mismatch_stride: {cnt_img_res_mismatch_stride}, cnt_vid_res_mismatch_stride: {cnt_vid_res_mismatch_stride}\n" + f"cnt_img_res_too_small: {cnt_img_res_too_small}, cnt_vid_res_too_small: {cnt_vid_res_too_small}\n" + f"cnt_img_aspect_mismatch: {cnt_img_aspect_mismatch}, cnt_vid_aspect_mismatch: {cnt_vid_aspect_mismatch}\n" + f"cnt_filter_minority: {cnt_filter_minority}\n" + f"cnt_no_existent: {cnt_no_existent}\n" + if self.filter_nonexistent + else "" + f"Counter(sample_size): {counter}\n" + f"cnt_vid: {cnt_vid}, cnt_vid_after_filter: {cnt_vid_after_filter}, use_ratio: {round(cnt_vid_after_filter/(cnt_vid+1e-6), 5)*100}%\n" + f"cnt_img: {cnt_img}, cnt_img_after_filter: {cnt_img_after_filter}, use_ratio: {round(cnt_img_after_filter/(cnt_img+1e-6), 5)*100}%\n" + f"before filter: {cnt}, after filter: {len(new_cap_list)}, use_ratio: {round(len(new_cap_list)/cnt, 5)*100}%" + ) + # import ipdb;ipdb.set_trace() + + if len(aesthetic_score) > 0: + stats_aesthetic = calculate_statistics(aesthetic_score) + print( + f"before filter: {cnt}, after filter: {len(new_cap_list)}\n" + f"aesthetic_score: {len(aesthetic_score)}, cnt_no_aesthetic: {cnt_no_aesthetic}\n" + f"{len([i for i in aesthetic_score if i>=5.75])} > 5.75, 4.5 > {len([i for i in aesthetic_score if i<=4.5])}\n" + f"Mean: {stats_aesthetic['mean']}, Var: {stats_aesthetic['variance']}, Std: {stats_aesthetic['std_dev']}\n" + f"Min: {stats_aesthetic['min']}, Max: {stats_aesthetic['max']}" + ) + + return new_cap_list, sample_size, shape_idx_dict def set_checkpoint(self, n_used_elements): for i in range(len(dataset_prog.n_used_elements)): @@ -185,16 +482,16 @@ def __len__(self): def __getitem__(self, idx): try: - data = self.get_data(idx) + future = self.executor.submit(self.get_data, idx) + data = future.result(timeout=self.timeout) + # data = self.get_data(idx) return data except Exception as e: - logger.info(f"Error with {e}") - # 打印异常堆栈 - if idx in dataset_prog.cap_list: - logger.info(f"Caught an exception! {dataset_prog.cap_list[idx]}") - # traceback.print_exc() - # traceback.print_stack() - return self.__getitem__(random.randint(0, self.__len__() - 1)) + if len(str(e)) < 2: + e = f"TimeoutError, {self.timeout}s timeout occur with {dataset_prog.cap_list[idx]['path']}" + print(f"Error with {e}") + index_cand = self.shape_idx_dict[self.sample_size[idx]] # pick same shape + return self.__getitem__(random.choice(index_cand)) def get_data(self, idx): path = dataset_prog.cap_list[idx]["path"] @@ -204,202 +501,233 @@ def get_data(self, idx): return self.get_image(idx) def get_video(self, idx): - video_path = dataset_prog.cap_list[idx]["path"] + video_data = dataset_prog.cap_list[idx] + video_path = video_data["path"] assert os.path.exists(video_path), f"file {video_path} do not exist!" - frame_indice = dataset_prog.cap_list[idx]["sample_frame_index"] - video = self.decord_read(video_path, predefine_num_frames=len(frame_indice)) # (T H W C) - - h, w = video.shape[1:3] - # NOTE: not suitable for 1:1 training in v1.3 - # assert h / w <= 17 / 16 and h / w >= 8 / 16, ( - # f"Only videos with a ratio (h/w) less than 17/16 and more than 8/16 are supported. 
But video ({video_path}) " - # + f"found ratio is {round(h / w, 2)} with the shape of {video.shape}" - # ) + sample_h = video_data["resolution"]["sample_height"] + sample_w = video_data["resolution"]["sample_width"] + if self.video_reader == "decord": + video = self.decord_read(video_data) + elif self.video_reader == "opencv": + video = self.opencv_read(video_data) + else: + NotImplementedError(f"Found {self.video_reader}, but support decord or opencv") + + h, w = video.shape[1:3] # (T, H, W, C) input_videos = {"image": video[0]} input_videos.update(dict([(f"image{i}", video[i + 1]) for i in range(len(video) - 1)])) output_videos = self.transform(**input_videos) video = np.stack([v for _, v in output_videos.items()], axis=0).transpose(3, 0, 1, 2) # T H W C -> C T H W - + assert ( + video.shape[2] == sample_h and video.shape[3] == sample_w + ), f"sample_h ({sample_h}), sample_w ({sample_w}), video ({video.shape})" # get token ids and attention mask if not self.return_text_emb if not self.return_text_emb: - text = dataset_prog.cap_list[idx]["cap"] + text = video_data["cap"] if not isinstance(text, list): text = [text] text = [random.choice(text)] + if video_data.get("aesthetic", None) is not None or video_data.get("aes", None) is not None: + aes = video_data.get("aesthetic", None) or video_data.get("aes", None) + text = [add_aesthetic_notice_video(text[0], aes)] + text = text_preprocessing(text, support_Chinese=self.support_Chinese) + + text = text if random.random() > self.cfg else "" - text = text_preprocessing(text, support_Chinese=self.support_Chinese) if random.random() > self.cfg else "" - text_tokens_and_mask = self.tokenizer( + text_tokens_and_mask_1 = self.tokenizer_1( text, max_length=self.model_max_length, padding="max_length", truncation=True, return_attention_mask=True, add_special_tokens=True, - return_tensors="np", + return_tensors="pt", ) - input_ids = text_tokens_and_mask["input_ids"] - cond_mask = text_tokens_and_mask["attention_mask"] - return dict(pixel_values=video, input_ids=input_ids, cond_mask=cond_mask) + input_ids_1 = text_tokens_and_mask_1["input_ids"] + cond_mask_1 = text_tokens_and_mask_1["attention_mask"] + + input_ids_2, cond_mask_2 = None, None + if self.tokenizer_2 is not None: + text_tokens_and_mask_2 = self.tokenizer_2( + text, + max_length=self.tokenizer_2.model_max_length, + padding="max_length", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids_2 = text_tokens_and_mask_2["input_ids"] + cond_mask_2 = text_tokens_and_mask_2["attention_mask"] + return dict( + pixel_values=video, + input_ids_1=input_ids_1, + cond_mask_1=cond_mask_1, + input_ids_2=input_ids_2, + cond_mask_2=cond_mask_2, + ) + else: - text_embed_paths = dataset_prog.cap_list[idx]["text_embed_path"] - text_embed_path = random.choice(text_embed_paths) - text_emb, cond_mask = self.parse_text_emb(text_embed_path) - return dict(pixel_values=video, input_ids=text_emb, cond_mask=cond_mask) + if "text_embed_path_1" in video_data: + text_embed_paths = video_data["text_embed_path_1"] + text_embed_path = random.choice(text_embed_paths) + text_emb_1, cond_mask_1 = self.parse_text_emb(text_embed_path) + text_emb_2, cond_mask_2 = None, None + if "text_embed_path_2" in video_data: + text_embed_paths = video_data["text_embed_path_2"] + text_embed_path = random.choice(text_embed_paths) + text_emb_2, cond_mask_2 = self.parse_text_emb(text_embed_path) + return dict( + pixel_values=video, + input_ids_1=text_emb_1, + cond_mask_1=cond_mask_1, + 
input_ids_2=text_emb_2, + cond_mask_2=cond_mask_2, + ) def get_image(self, idx): image_data = dataset_prog.cap_list[idx] # [{'path': path, 'cap': cap}, ...] - + sample_h = image_data["resolution"]["sample_height"] + sample_w = image_data["resolution"]["sample_width"] # import ipdb;ipdb.set_trace() image = Image.open(image_data["path"]).convert("RGB") # [h, w, c] image = np.array(image) # [h, w, c] - image = ( - self.transform_topcrop(image=image)["image"] - if "human_images" in image_data["path"] - else self.transform(image=image)["image"] - ) + image = self.transform(image=image)["image"] # [h, w, c] -> [c h w] -> [C 1 H W] image = image.transpose(2, 0, 1)[:, None, ...] + assert ( + image.shape[2] == sample_h and image.shape[3] == sample_w + ), f"image_data: {image_data}, but found image {image.shape}" # get token ids and attention mask if not self.return_text_emb if not self.return_text_emb: caps = image_data["cap"] if isinstance(image_data["cap"], list) else [image_data["cap"]] caps = [random.choice(caps)] + if image_data.get("aesthetic", None) is not None or image_data.get("aes", None) is not None: + aes = image_data.get("aesthetic", None) or image_data.get("aes", None) + caps = [add_aesthetic_notice_image(caps[0], aes)] text = text_preprocessing(caps, support_Chinese=self.support_Chinese) - input_ids, cond_mask = [], [] text = text if random.random() > self.cfg else "" - text_tokens_and_mask = self.tokenizer( + + text_tokens_and_mask_1 = self.tokenizer_1( text, max_length=self.model_max_length, padding="max_length", truncation=True, return_attention_mask=True, add_special_tokens=True, - return_tensors="np", + return_tensors="pt", + ) + input_ids_1 = text_tokens_and_mask_1["input_ids"] # 1, l + cond_mask_1 = text_tokens_and_mask_1["attention_mask"] # 1, l + + input_ids_2, cond_mask_2 = None, None + if self.tokenizer_2 is not None: + text_tokens_and_mask_2 = self.tokenizer_2( + text, + max_length=self.tokenizer_2.model_max_length, + padding="max_length", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids_2 = text_tokens_and_mask_2["input_ids"] # 1, l + cond_mask_2 = text_tokens_and_mask_2["attention_mask"] # 1, l + + return dict( + pixel_values=image, + input_ids_1=input_ids_1, + cond_mask_1=cond_mask_1, + input_ids_2=input_ids_2, + cond_mask_2=cond_mask_2, ) - input_ids = text_tokens_and_mask["input_ids"] # 1, l - cond_mask = text_tokens_and_mask["attention_mask"] # 1, l - return dict(pixel_values=image, input_ids=input_ids, cond_mask=cond_mask) else: - text_embed_paths = dataset_prog.cap_list[idx]["text_embed_path"] - text_embed_path = random.choice(text_embed_paths) - text_emb, cond_mask = self.parse_text_emb(text_embed_path) - return dict(pixel_values=image, input_ids=text_emb, cond_mask=cond_mask) + if "text_embed_path_1" in image_data: + text_embed_paths = image_data["text_embed_path_1"] + text_embed_path = random.choice(text_embed_paths) + text_emb_1, cond_mask_1 = self.parse_text_emb(text_embed_path) + text_emb_2, cond_mask_2 = None, None + if "text_embed_path_2" in image_data: + text_embed_paths = image_data["text_embed_path_2"] + text_embed_path = random.choice(text_embed_paths) + text_emb_2, cond_mask_2 = self.parse_text_emb(text_embed_path) + return dict( + pixel_values=image, + input_ids_1=text_emb_1, + cond_mask_1=cond_mask_1, + input_ids_2=text_emb_2, + cond_mask_2=cond_mask_2, + ) - def define_frame_index(self, cap_list): - new_cap_list = [] - sample_num_frames = [] - cnt_too_long = 0 - cnt_too_short = 0 - 
cnt_no_cap = 0 - cnt_no_resolution = 0 - cnt_resolution_mismatch = 0 - cnt_movie = 0 - cnt_img = 0 - for i in cap_list: - path = i["path"] - cap = i.get("cap", None) - # ======no caption===== - if cap is None: - cnt_no_cap += 1 - continue - if path.endswith(".mp4"): - # ======no fps and duration===== - duration = i.get("duration", None) - fps = i.get("fps", None) - if fps is None or duration is None: - continue + def decord_read(self, video_data): + path = video_data["path"] + predefine_frame_indice = video_data["sample_frame_index"] + start_frame_idx = video_data["start_frame_idx"] + clip_total_frames = video_data["num_frames"] + fps = video_data["fps"] + s_x, e_x, s_y, e_y = video_data.get("crop", [None, None, None, None]) - # ======resolution mismatch===== - resolution = i.get("resolution", None) - if resolution is None: - cnt_no_resolution += 1 - continue - else: - if resolution.get("height", None) is None or resolution.get("width", None) is None: - cnt_no_resolution += 1 - continue - height, width = i["resolution"]["height"], i["resolution"]["width"] - aspect = self.max_height / self.max_width - hw_aspect_thr = 2.0 #NOTE: for 1:1 frame training - is_pick = filter_resolution( - height, - width, - max_h_div_w_ratio=hw_aspect_thr * aspect, - min_h_div_w_ratio=1 / hw_aspect_thr * aspect, - ) - if not is_pick: - cnt_resolution_mismatch += 1 - continue + predefine_num_frames = len(predefine_frame_indice) + # decord_vr = decord.VideoReader(path, ctx=decord.cpu(0), num_threads=1) + decord_vr = DecordDecoder(path) - # import ipdb;ipdb.set_trace() - i["num_frames"] = int(fps * duration) - # max 5.0 and min 1.0 are just thresholds to filter some videos which have suitable duration. - if i["num_frames"] > self.duration_threshold * ( - self.num_frames * fps / self.train_fps * self.speed_factor - ): # too long video is not suitable for this training stage (self.num_frames) - cnt_too_long += 1 - continue + frame_indices = self.get_actual_frame( + fps, start_frame_idx, clip_total_frames, path, predefine_num_frames, predefine_frame_indice + ) - # resample in case high fps, such as 50/60/90/144 -> train_fps(e.g, 24) - frame_interval = fps / self.train_fps - start_frame_idx = 0 - frame_indices = np.arange(start_frame_idx, i["num_frames"], frame_interval).astype(int) - frame_indices = frame_indices[frame_indices < i["num_frames"]] + # video_data = decord_vr.get_batch(frame_indices).asnumpy() + # video_data = torch.from_numpy(video_data) + video_data = decord_vr.get_batch(frame_indices) + if video_data is not None: + if s_y is not None: + video_data = video_data[ + :, + s_y:e_y, + s_x:e_x, + :, + ] + else: + raise ValueError(f"Get video_data {video_data}") - # comment out it to enable dynamic frames training - if len(frame_indices) < self.num_frames and random.random() < self.drop_short_ratio: - cnt_too_short += 1 - continue + return video_data - # too long video will be temporal-crop randomly - if len(frame_indices) > self.num_frames: - begin_index, end_index = self.temporal_sample(len(frame_indices)) - frame_indices = frame_indices[begin_index:end_index] - # frame_indices = frame_indices[:self.num_frames] # head crop - # to find a suitable end_frame_idx, to ensure we do not need pad video - end_frame_idx = find_closest_y(len(frame_indices), vae_stride_t=4, model_ds_t=4) - if end_frame_idx == -1: # too short that can not be encoded exactly by videovae - if self.num_frames < 29: - logger.warning( - "The numbder of frames is less than 29, which is too short to be encoded by causal vae." 
- ) - cnt_too_short += 1 - continue - frame_indices = frame_indices[:end_frame_idx] - - i["sample_frame_index"] = frame_indices.tolist() - new_cap_list.append(i) - i["sample_num_frames"] = len(i["sample_frame_index"]) # will use in dataloader(group sampler) - sample_num_frames.append(i["sample_num_frames"]) - elif path.endswith(".jpg"): # image - cnt_img += 1 - new_cap_list.append(i) - i["sample_num_frames"] = 1 - sample_num_frames.append(i["sample_num_frames"]) - else: - raise NameError( - f"Unknown file extention {path.split('.')[-1]}, only support .mp4 for video and .jpg for image" - ) - # import ipdb;ipdb.set_trace() - logger.info( - f"no_cap: {cnt_no_cap}, too_long: {cnt_too_long}, too_short: {cnt_too_short}, " - f"no_resolution: {cnt_no_resolution}, resolution_mismatch: {cnt_resolution_mismatch}, " - f"Counter(sample_num_frames): {Counter(sample_num_frames)}, cnt_movie: {cnt_movie}, cnt_img: {cnt_img}, " - f"before filter: {len(cap_list)}, after filter: {len(new_cap_list)}" + def opencv_read(self, video_data): + path = video_data["path"] + predefine_frame_indice = video_data["sample_frame_index"] + start_frame_idx = video_data["start_frame_idx"] + clip_total_frames = video_data["num_frames"] + fps = video_data["fps"] + s_x, e_x, s_y, e_y = video_data.get("crop", [None, None, None, None]) + + predefine_num_frames = len(predefine_frame_indice) + cv2_vr = cv2.VideoCapture(path) + if not cv2_vr.isOpened(): + raise ValueError(f"can not open {path}") + frame_indices = self.get_actual_frame( + fps, start_frame_idx, clip_total_frames, path, predefine_num_frames, predefine_frame_indice ) - return new_cap_list, sample_num_frames - def decord_read(self, path, predefine_num_frames): - decord_vr = self.v_decoder(path) - total_frames = len(decord_vr) - fps = decord_vr.get_avg_fps() if decord_vr.get_avg_fps() > 0 else 30.0 - # import ipdb;ipdb.set_trace() + video_data = [] + for frame_idx in frame_indices: + cv2_vr.set(1, frame_idx) + _, frame = cv2_vr.read() + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + video_data.append(frame) # H, W, C + cv2_vr.release() + video_data = np.stack(video_data) # (T, H, W, C) + if s_y is not None: + video_data = video_data[:, s_y:e_y, s_x:e_x, :] + return video_data + + def get_actual_frame( + self, fps, start_frame_idx, clip_total_frames, path, predefine_num_frames, predefine_frame_indice + ): # resample in case high fps, such as 50/60/90/144 -> train_fps(e.g, 24) frame_interval = 1.0 if abs(fps - self.train_fps) < 0.1 else fps / self.train_fps - start_frame_idx = 0 - frame_indices = np.arange(start_frame_idx, total_frames, frame_interval).astype(int) - frame_indices = frame_indices[frame_indices < total_frames] - # import ipdb;ipdb.set_trace() + frame_indices = np.arange(start_frame_idx, start_frame_idx + clip_total_frames, frame_interval).astype(int) + frame_indices = frame_indices[frame_indices < start_frame_idx + clip_total_frames] + # speed up max_speed_factor = len(frame_indices) / self.num_frames if self.speed_factor > 1 and max_speed_factor > 1: @@ -416,22 +744,22 @@ def decord_read(self, path, predefine_num_frames): # frame_indices = frame_indices[:self.num_frames] # head crop # to find a suitable end_frame_idx, to ensure we do not need pad video - end_frame_idx = find_closest_y(len(frame_indices), vae_stride_t=4, model_ds_t=4) + end_frame_idx = find_closest_y(len(frame_indices), vae_stride_t=self.ae_stride_t, model_ds_t=self.sp_size) if end_frame_idx == -1: # too short that can not be encoded exactly by videovae raise IndexError( - f"video ({path}) 
has {total_frames} frames, but need to sample {len(frame_indices)} frames ({frame_indices})" + f"video ({path}) has {clip_total_frames} frames, but need to sample {len(frame_indices)} frames ({frame_indices})" ) frame_indices = frame_indices[:end_frame_idx] if predefine_num_frames != len(frame_indices): raise ValueError( - f"predefine_num_frames ({predefine_num_frames}) is not equal with frame_indices ({len(frame_indices)})" + f"video ({path}) predefine_num_frames ({predefine_num_frames}) ({predefine_frame_indice}) is \ + not equal with frame_indices ({len(frame_indices)}) ({frame_indices})" ) if len(frame_indices) < self.num_frames and self.drop_short_ratio >= 1: raise IndexError( - f"video ({path}) has {total_frames} frames, but need to sample {len(frame_indices)} frames ({frame_indices})" + f"video ({path}) has {clip_total_frames} frames, but need to sample {len(frame_indices)} frames ({frame_indices})" ) - video_data = decord_vr.get_batch(frame_indices).asnumpy() # (T, H, W, C) - return video_data + return frame_indices def get_text_embed_file_path(self, item): file_path = item["path"] @@ -462,33 +790,3 @@ def parse_text_emb(self, npz): mask = mask[None, ...] return text_emb, mask # (1, L, D), (1, L) - - def read_jsons(self, data): - cap_lists = [] - with open(data, "r") as f: - folder_anno = [i.strip().split(",") for i in f.readlines() if len(i.strip()) > 0] - for item in folder_anno: - if len(item) == 2: - folder, anno = item - elif len(item) == 3: - folder, text_emb_folder, anno = item - else: - raise ValueError(f"Expect to have two or three paths, but got {len(item)} input paths") - if self.return_text_emb: - assert ( - len(item) == 3 - ), "When returning text embeddings, please give three paths: video folder, text_embed folder, annotation file" - with open(anno, "r") as f: - sub_list = json.load(f) - logger.info(f"Building {anno}...") - for i in range(len(sub_list)): - if self.return_text_emb: - text_embeds_paths = self.get_text_embed_file_path(sub_list[i]) - sub_list[i]["text_embed_path"] = [opj(text_emb_folder, tp) for tp in text_embeds_paths] - sub_list[i]["path"] = opj(folder, sub_list[i]["path"]) - cap_lists += sub_list - return cap_lists - - def get_cap_list(self): - cap_lists = self.read_jsons(self.data) - return cap_lists diff --git a/examples/opensora_pku/opensora/dataset/transform.py b/examples/opensora_pku/opensora/dataset/transform.py index b1edd43fa0..12734c8d85 100644 --- a/examples/opensora_pku/opensora/dataset/transform.py +++ b/examples/opensora_pku/opensora/dataset/transform.py @@ -5,8 +5,8 @@ import albumentations import ftfy -from bs4 import BeautifulSoup import numpy as np +from bs4 import BeautifulSoup __all__ = ["create_video_transforms", "t5_text_preprocessing"] @@ -91,20 +91,22 @@ def center_crop_th_tw(image, th, tw, top_crop, **kwargs): cropped_image = crop(image, i, j, new_h, new_w) return cropped_image -def resize(image, h, w, interpolation_mode): - resize_func = albumentations.Resize(h, w, interpolation = interpolation_mode) +def resize(image, h, w, interpolation_mode): + resize_func = albumentations.Resize(h, w, interpolation=interpolation_mode) return resize_func(image=image)["image"] + def get_params(h, w, stride): th, tw = h // stride * stride, w // stride * stride - + i = (h - th) // 2 j = (w - tw) // 2 return i, j, th, tw + def spatial_stride_crop_video(image, stride, **kwargs): """ Args: @@ -113,31 +115,33 @@ def spatial_stride_crop_video(image, stride, **kwargs): numpy array: cropped video clip by stride. 
size is (OH, OW, C) """ - h, w = image.shape[:2] + h, w = image.shape[:2] i, j, h, w = get_params(h, w, stride) return crop(image, i, j, h, w) + def maxhxw_resize(image, max_hxw, interpolation_mode, **kwargs): - """ - First use the h*w, - then resize to the specified size - Args: - image (numpy array): Video clip to be cropped. Size is (H, W, C) - Returns: - numpy array: scale resized video clip. - """ - h, w = image.shape[:2] - if h * w > max_hxw: - scale_factor = np.sqrt(max_hxw / (h * w)) - tr_h = int(h * scale_factor) - tr_w = int(w * scale_factor) - else: - tr_h = h - tr_w = w - if h == tr_h and w == tr_w: - return image - resize_image = resize(image, tr_h, tr_w, interpolation_mode) - return resize_image + """ + First use the h*w, + then resize to the specified size + Args: + image (numpy array): Video clip to be cropped. Size is (H, W, C) + Returns: + numpy array: scale resized video clip. + """ + h, w = image.shape[:2] + if h * w > max_hxw: + scale_factor = np.sqrt(max_hxw / (h * w)) + tr_h = int(h * scale_factor) + tr_w = int(w * scale_factor) + else: + tr_h = h + tr_w = w + if h == tr_h and w == tr_w: + return image + resize_image = resize(image, tr_h, tr_w, interpolation_mode) + return resize_image + # create text transform(preprocess) bad_punct_regex = re.compile( @@ -311,3 +315,179 @@ def __call__(self, t, h, w): if self.extra_1: truncate_t = truncate_t + 1 return 0, truncate_t + + +keywords = [ + " man ", + " woman ", + " person ", + " people ", + "human", + " individual ", + " child ", + " kid ", + " girl ", + " boy ", +] +keywords += [i[:-1] + "s " for i in keywords] + +masking_notices = [ + "Note: The faces in this image are blurred.", + "This image contains faces that have been pixelated.", + "Notice: Faces in this image are masked.", + "Please be aware that the faces in this image are obscured.", + "The faces in this image are hidden.", + "This is an image with blurred faces.", + "The faces in this image have been processed.", + "Attention: Faces in this image are not visible.", + "The faces in this image are partially blurred.", + "This image has masked faces.", + "Notice: The faces in this picture have been altered.", + "This is a picture with obscured faces.", + "The faces in this image are pixelated.", + "Please note, the faces in this image have been blurred.", + "The faces in this photo are hidden.", + "The faces in this picture have been masked.", + "Note: The faces in this picture are altered.", + "This is an image where faces are not clear.", + "Faces in this image have been obscured.", + "This picture contains masked faces.", + "The faces in this image are processed.", + "The faces in this picture are not visible.", + "Please be aware, the faces in this photo are pixelated.", + "The faces in this picture have been blurred.", +] + +webvid_watermark_notices = [ + "This video has a faint Shutterstock watermark in the center.", + "There is a slight Shutterstock watermark in the middle of this video.", + "The video contains a subtle Shutterstock watermark in the center.", + "This video features a light Shutterstock watermark at its center.", + "A faint Shutterstock watermark is present in the middle of this video.", + "There is a mild Shutterstock watermark at the center of this video.", + "This video has a slight Shutterstock watermark in the middle.", + "You can see a faint Shutterstock watermark in the center of this video.", + "A subtle Shutterstock watermark appears in the middle of this video.", + "This video includes a light Shutterstock watermark at its 
center.", +] + + +high_aesthetic_score_notices_video = [ + "This video has a high aesthetic quality.", + "The beauty of this video is exceptional.", + "This video scores high in aesthetic value.", + "With its harmonious colors and balanced composition.", + "This video ranks highly for aesthetic quality", + "The artistic quality of this video is excellent.", + "This video is rated high for beauty.", + "The aesthetic quality of this video is impressive.", + "This video has a top aesthetic score.", + "The visual appeal of this video is outstanding.", +] + +low_aesthetic_score_notices_video = [ + "This video has a low aesthetic quality.", + "The beauty of this video is minimal.", + "This video scores low in aesthetic appeal.", + "The aesthetic quality of this video is below average.", + "This video ranks low for beauty.", + "The artistic quality of this video is lacking.", + "This video has a low score for aesthetic value.", + "The visual appeal of this video is low.", + "This video is rated low for beauty.", + "The aesthetic quality of this video is poor.", +] + + +high_aesthetic_score_notices_image = [ + "This image has a high aesthetic quality.", + "The beauty of this image is exceptional", + "This photo scores high in aesthetic value.", + "With its harmonious colors and balanced composition.", + "This image ranks highly for aesthetic quality.", + "The artistic quality of this photo is excellent.", + "This image is rated high for beauty.", + "The aesthetic quality of this image is impressive.", + "This photo has a top aesthetic score.", + "The visual appeal of this image is outstanding.", +] + +low_aesthetic_score_notices_image = [ + "This image has a low aesthetic quality.", + "The beauty of this image is minimal.", + "This image scores low in aesthetic appeal.", + "The aesthetic quality of this image is below average.", + "This image ranks low for beauty.", + "The artistic quality of this image is lacking.", + "This image has a low score for aesthetic value.", + "The visual appeal of this image is low.", + "This image is rated low for beauty.", + "The aesthetic quality of this image is poor.", +] + +high_aesthetic_score_notices_image_human = [ + "High-quality image with visible human features and high aesthetic score.", + "Clear depiction of an individual in a high-quality image with top aesthetics.", + "High-resolution photo showcasing visible human details and high beauty rating.", + "Detailed, high-quality image with well-defined human subject and strong aesthetic appeal.", + "Sharp, high-quality portrait with clear human features and high aesthetic value.", + "High-quality image featuring a well-defined human presence and exceptional aesthetics.", + "Visible human details in a high-resolution photo with a high aesthetic score.", + "Clear, high-quality image with prominent human subject and superior aesthetic rating.", + "High-quality photo capturing a visible human with excellent aesthetics.", + "Detailed, high-quality image of a human with high visual appeal and aesthetic value.", +] + + +def calculate_statistics(data): + if len(data) == 0: + return None + data = np.array(data) + mean = np.mean(data) + variance = np.var(data) + std_dev = np.std(data) + minimum = np.min(data) + maximum = np.max(data) + + return {"mean": mean, "variance": variance, "std_dev": std_dev, "min": minimum, "max": maximum} + + +def maxhwresize(ori_height, ori_width, max_hxw): + if ori_height * ori_width > max_hxw: + scale_factor = np.sqrt(max_hxw / (ori_height * ori_width)) + new_height = int(ori_height * 
scale_factor) + new_width = int(ori_width * scale_factor) + else: + new_height = ori_height + new_width = ori_width + return new_height, new_width + + +def add_aesthetic_notice_video(caption, aesthetic_score): + if aesthetic_score <= 4.25: + notice = random.choice(low_aesthetic_score_notices_video) + return random.choice([caption + " " + notice, notice + " " + caption]) + if aesthetic_score >= 5.75: + notice = random.choice(high_aesthetic_score_notices_video) + return random.choice([caption + " " + notice, notice + " " + caption]) + return caption + + +def add_aesthetic_notice_image(caption, aesthetic_score): + if aesthetic_score <= 4.25: + notice = random.choice(low_aesthetic_score_notices_image) + return random.choice([caption + " " + notice, notice + " " + caption]) + if aesthetic_score >= 5.75: + notice = random.choice(high_aesthetic_score_notices_image) + return random.choice([caption + " " + notice, notice + " " + caption]) + return caption + + +def add_high_aesthetic_notice_image(caption): + notice = random.choice(high_aesthetic_score_notices_image) + return random.choice([caption + " " + notice, notice + " " + caption]) + + +def add_high_aesthetic_notice_image_human(caption): + notice = random.choice(high_aesthetic_score_notices_image_human) + return random.choice([caption + " " + notice, notice + " " + caption]) diff --git a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py index 93a4295c09..5c8e71398b 100644 --- a/examples/opensora_pku/opensora/train/train_t2v_diffusers.py +++ b/examples/opensora_pku/opensora/train/train_t2v_diffusers.py @@ -26,7 +26,7 @@ from opensora.npu_config import npu_config from opensora.train.commons import create_loss_scaler, parse_args from opensora.utils.callbacks import EMAEvalSwapCallback, PerfRecorderCallback -from opensora.utils.dataset_utils import Collate, LengthGroupedBatchSampler +from opensora.utils.dataset_utils import Collate, LengthGroupedSampler from opensora.utils.ema import EMA from opensora.utils.message_utils import print_banner from opensora.utils.utils import get_precision, save_diffusers_json @@ -299,36 +299,22 @@ def main(args): initial_global_step_for_sampler = args.trained_data_global_step else: initial_global_step_for_sampler = 0 + total_batch_size = args.train_batch_size * device_num * args.gradient_accumulation_steps + total_batch_size = total_batch_size // args.sp_size * args.train_sp_batch_size + args.total_batch_size = total_batch_size if args.max_hxw is not None and args.min_hxw is None: args.min_hxw = args.max_hxw // 4 train_dataset = getdataset(args, dataset_file=args.data) - sampler = ( - LengthGroupedBatchSampler( - args.train_batch_size, - world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), - lengths=train_dataset.lengths, - group_frame=args.group_frame, # v1.2 - group_resolution=args.group_resolution, # v1.2 - initial_global_step_for_sampler=initial_global_step_for_sampler, # TODO: use in v1.3 - group_data=args.group_data, # TODO: use in v1.3 - ) - if (args.group_frame or args.group_resolution) # v1.2 - else None # v1.2 - ) - collate_fn = Collate( + sampler = LengthGroupedSampler( args.train_batch_size, - args.group_frame, - args.group_resolution, - args.max_height, - args.max_width, - args.ae_stride, - args.ae_stride_t, - args.patch_size, - args.patch_size_t, - args.num_frames, - args.use_image_num, + world_size=device_num if not get_sequence_parallel_state() else (device_num // 
hccl_info.world_size), + gradient_accumulation_size=args.gradient_accumulation_steps, + initial_global_step=initial_global_step_for_sampler, + lengths=train_dataset.lengths, + group_data=args.group_data, ) + collate_fn = Collate(args) dataloader = create_dataloader( train_dataset, batch_size=args.train_batch_size, @@ -353,17 +339,15 @@ def main(args): assert os.path.exists(args.val_data), f"validation dataset file must exist, but got {args.val_data}" print_banner("Validation dataset Loading...") val_dataset = getdataset(args, dataset_file=args.val_data) - sampler = ( - LengthGroupedBatchSampler( - args.val_batch_size, - world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), - lengths=val_dataset.lengths, - group_frame=args.group_frame, - group_resolution=args.group_resolution, - ) - if (args.group_frame or args.group_resolution) - else None + sampler = LengthGroupedSampler( + args.val_batch_size, + world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), + lengths=val_dataset.lengths, + gradient_accumulation_size=args.gradient_accumulation_steps, + initial_global_step=initial_global_step_for_sampler, + group_data=args.group_data, ) + collate_fn = Collate( args.val_batch_size, args.group_frame, @@ -674,8 +658,6 @@ def main(args): callback.append(ProfilerCallbackEpoch(2, 2, "./profile_data")) # Train! - total_batch_size = args.train_batch_size * device_num * args.gradient_accumulation_steps - total_batch_size = total_batch_size // args.sp_size * args.train_sp_batch_size # 5. log and save config if rank_id == 0: @@ -761,7 +743,12 @@ def parse_t2v_train_args(parser): parser.add_argument("--hw_stride", type=int, default=32) parser.add_argument("--force_resolution", action="store_true") parser.add_argument("--trained_data_global_step", type=int, default=None) - parser.add_argument("--use_decord", action="store_true") + parser.add_argument( + "--use_decord", + type=str2bool, + default=True, + help="whether to use decord to load videos. 
If not, use opencv to load videos.", + ) # text encoder & vae & diffusion model parser.add_argument("--vae_fp32", action="store_true") diff --git a/examples/opensora_pku/opensora/utils/dataset_utils.py b/examples/opensora_pku/opensora/utils/dataset_utils.py index be4c472977..27caeb88eb 100644 --- a/examples/opensora_pku/opensora/utils/dataset_utils.py +++ b/examples/opensora_pku/opensora/utils/dataset_utils.py @@ -1,6 +1,6 @@ import math import random -from collections import Counter +from collections import Counter, defaultdict from typing import List, Optional import decord @@ -128,61 +128,77 @@ def pad_to_multiple(number, ds_stride): class Collate: - def __init__( - self, - batch_size, - group_frame, - group_resolution, - max_height, - max_width, - ae_stride, - ae_stride_t, - patch_size, - patch_size_t, - num_frames, - use_image_num, - ): - self.batch_size = batch_size - self.group_frame = group_frame - self.group_resolution = group_resolution + def __init__(self, args): + self.batch_size = args.train_batch_size + self.group_data = args.group_data + self.force_resolution = args.force_resolution - self.max_height = max_height - self.max_width = max_width - self.ae_stride = ae_stride + self.max_height = args.max_height + self.max_width = args.max_width + self.ae_stride = args.ae_stride - self.ae_stride_t = ae_stride_t + self.ae_stride_t = args.ae_stride_t self.ae_stride_thw = (self.ae_stride_t, self.ae_stride, self.ae_stride) - self.patch_size = patch_size - self.patch_size_t = patch_size_t + self.patch_size = args.patch_size + self.patch_size_t = args.patch_size_t - self.num_frames = num_frames - self.use_image_num = use_image_num + self.num_frames = args.num_frames self.max_thw = (self.num_frames, self.max_height, self.max_width) def package(self, batch): batch_tubes = [i["pixel_values"] for i in batch] # b [c t h w] - input_ids = [i["input_ids"] for i in batch] # b [1 l] - cond_mask = [i["cond_mask"] for i in batch] # b [1 l] - return batch_tubes, input_ids, cond_mask + input_ids_1 = [i["input_ids_1"] for i in batch] # b [1 l] + cond_mask_1 = [i["cond_mask_1"] for i in batch] # b [1 l] + input_ids_2 = [i["input_ids_2"] for i in batch] # b [1 l] + cond_mask_2 = [i["cond_mask_2"] for i in batch] # b [1 l] + assert all([i is None for i in input_ids_2]) or all([i is not None for i in input_ids_2]) + assert all([i is None for i in cond_mask_2]) or all([i is not None for i in cond_mask_2]) + if all([i is None for i in input_ids_2]): + input_ids_2 = None + if all([i is None for i in cond_mask_2]): + cond_mask_2 = None + return batch_tubes, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 def __call__(self, batch): - batch_tubes, input_ids, cond_mask = self.package(batch) + batch_tubes, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 = self.package(batch) ds_stride = self.ae_stride * self.patch_size t_ds_stride = self.ae_stride_t * self.patch_size_t - pad_batch_tubes, attention_mask, input_ids, cond_mask = self.process( - batch_tubes, input_ids, cond_mask, t_ds_stride, ds_stride, self.max_thw, self.ae_stride_thw + pad_batch_tubes, attention_mask, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 = self.process( + batch_tubes, + input_ids_1, + cond_mask_1, + input_ids_2, + cond_mask_2, + t_ds_stride, + ds_stride, + self.max_thw, + self.ae_stride_thw, ) - # assert not np.any(np.isnan(pad_batch_tubes)), 'after pad_batch_tubes' - return pad_batch_tubes, attention_mask, input_ids, cond_mask + assert not np.any(np.isnan(pad_batch_tubes)), "after pad_batch_tubes" + if input_ids_2 is not None 
and cond_mask_2 is not None: + return pad_batch_tubes, attention_mask, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 + else: + return pad_batch_tubes, attention_mask, input_ids_1, cond_mask_1 - def process(self, batch_tubes, input_ids, cond_mask, t_ds_stride, ds_stride, max_thw, ae_stride_thw): + def process( + self, + batch_tubes, + input_ids_1, + cond_mask_1, + input_ids_2, + cond_mask_2, + t_ds_stride, + ds_stride, + max_thw, + ae_stride_thw, + ): # pad to max multiple of ds_stride batch_input_size = [i.shape for i in batch_tubes] # [(c t h w), (c t h w)] assert len(batch_input_size) == self.batch_size - if self.group_frame or self.group_resolution or self.batch_size == 1: # + if self.group_data or self.batch_size == 1: # len_each_batch = batch_input_size idx_length_dict = dict([*zip(list(range(self.batch_size)), len_each_batch)]) count_dict = Counter(len_each_batch) @@ -195,13 +211,25 @@ def process(self, batch_tubes, input_ids, cond_mask, t_ds_stride, ds_stride, max random_select_batch = [ random.choice(candidate_batch) for _ in range(len(len_each_batch) - len(candidate_batch)) ] - # print(batch_input_size, idx_length_dict, count_dict, sorted_by_value, pick_length, candidate_batch, random_select_batch) + print( + batch_input_size, + idx_length_dict, + count_dict, + sorted_by_value, + pick_length, + candidate_batch, + random_select_batch, + ) pick_idx = candidate_batch + random_select_batch batch_tubes = [batch_tubes[i] for i in pick_idx] batch_input_size = [i.shape for i in batch_tubes] # [(c t h w), (c t h w)] - input_ids = [input_ids[i] for i in pick_idx] # b [1, l] - cond_mask = [cond_mask[i] for i in pick_idx] # b [1, l] + input_ids_1 = [input_ids_1[i] for i in pick_idx] # b [1, l] + cond_mask_1 = [cond_mask_1[i] for i in pick_idx] # b [1, l] + if input_ids_2 is not None: + input_ids_2 = [input_ids_2[i] for i in pick_idx] # b [1, l] + if cond_mask_2 is not None: + cond_mask_2 = [cond_mask_2[i] for i in pick_idx] # b [1, l] for i in range(1, self.batch_size): assert batch_input_size[0] == batch_input_size[i] @@ -217,7 +245,6 @@ def process(self, batch_tubes, input_ids, cond_mask, t_ds_stride, ds_stride, max ) pad_max_t = pad_max_t + 1 - self.ae_stride_t each_pad_t_h_w = [[pad_max_t - i.shape[1], pad_max_h - i.shape[2], pad_max_w - i.shape[3]] for i in batch_tubes] - pad_batch_tubes = [ np.pad(im, [[0, 0]] * (len(im.shape) - 3) + [[0, pad_t], [0, pad_h], [0, pad_w]], constant_values=0) for (pad_t, pad_h, pad_w), im in zip(each_pad_t_h_w, batch_tubes) @@ -248,81 +275,61 @@ def process(self, batch_tubes, input_ids, cond_mask, t_ds_stride, ds_stride, max for i in valid_latent_size ] attention_mask = np.stack(attention_mask, axis=0) # b t h w - if self.batch_size == 1 or self.group_frame or self.group_resolution: + if self.batch_size == 1 or self.group_data: + if not np.all(attention_mask.astype(np.bool_)): + print( + batch_input_size, + (max_t, max_h, max_w), + (pad_max_t, pad_max_h, pad_max_w), + each_pad_t_h_w, + max_latent_size, + valid_latent_size, + ) assert np.all(attention_mask.astype(np.bool_)) - input_ids = np.stack(input_ids, axis=0) # b 1 l - cond_mask = np.stack(cond_mask, axis=0) # b 1 l - if input_ids.dtype == np.int64: - input_ids = input_ids.astype(np.int32) - if attention_mask.dtype == np.int64: - attention_mask = attention_mask.astype(np.int32) - if cond_mask.dtype == np.int64: - cond_mask = cond_mask.astype(np.int32) - return pad_batch_tubes, attention_mask, input_ids, cond_mask + input_ids_1 = np.stack(input_ids_1) # b 1 l + cond_mask_1 = np.stack(cond_mask_1) 
# b 1 l + input_ids_2 = np.stack(input_ids_2) if input_ids_2 is not None else input_ids_2 # b 1 l + cond_mask_2 = np.stack(cond_mask_2) if cond_mask_2 is not None else cond_mask_2 # b 1 l + return pad_batch_tubes, attention_mask, input_ids_1, cond_mask_1, input_ids_2, cond_mask_2 -def split_to_even_chunks(indices, lengths, num_chunks, batch_size): - """ - Split a list of indices into `chunks` chunks of roughly equal lengths. - """ - if len(indices) % num_chunks != 0: - chunks = [indices[i::num_chunks] for i in range(num_chunks)] - else: - num_indices_per_chunk = len(indices) // num_chunks - - chunks = [[] for _ in range(num_chunks)] - chunks_lengths = [0 for _ in range(num_chunks)] - for index in indices: - shortest_chunk = chunks_lengths.index(min(chunks_lengths)) - chunks[shortest_chunk].append(index) - chunks_lengths[shortest_chunk] += lengths[index] - if len(chunks[shortest_chunk]) == num_indices_per_chunk: - chunks_lengths[shortest_chunk] = float("inf") - # return chunks - - pad_chunks = [] - for idx, chunk in enumerate(chunks): - if batch_size != len(chunk): - assert batch_size > len(chunk) - if len(chunk) != 0: - chunk = chunk + [random.choice(chunk) for _ in range(batch_size - len(chunk))] - else: - chunk = random.choice(pad_chunks) - print(chunks[idx], "->", chunk) - pad_chunks.append(chunk) - return pad_chunks - - -def group_frame_fun(indices, lengths): - # sort by num_frames - indices.sort(key=lambda i: lengths[i], reverse=True) - return indices +def group_data_fun(lengths, generator=None): + # counter is decrease order + counter = Counter(lengths) # counter {'1x256x256': 3, ''} lengths ['1x256x256', '1x256x256', '1x256x256', ...] + grouped_indices = defaultdict(list) + for idx, item in enumerate(lengths): # group idx to a list + grouped_indices[item].append(idx) + grouped_indices = dict(grouped_indices) # {'1x256x256': [0, 1, 2], ...} + sorted_indices = [grouped_indices[item] for (item, _) in sorted(counter.items(), key=lambda x: x[1], reverse=True)] -def group_resolution_fun(indices): - raise NotImplementedError - return indices + # shuffle in each group + shuffle_sorted_indices = [] + for indice in sorted_indices: + shuffle_idx = generator.permutation(len(indice)).tolist() + shuffle_sorted_indices.extend([indice[idx] for idx in shuffle_idx]) + return shuffle_sorted_indices -def group_frame_and_resolution_fun(indices): - raise NotImplementedError - return indices - - -def last_group_frame_fun(shuffled_megabatches, lengths): +def last_group_data_fun(shuffled_megabatches, lengths): + # lengths ['1x256x256', '1x256x256', '1x256x256' ...] 
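+    # After the megabatches have been shuffled, force every batch inside each megabatch to a
+    # single shape key: keep the indices whose key is the most frequent in that batch and fill
+    # the remaining slots with random re-picks from that majority group.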
re_shuffled_megabatches = [] # print('shuffled_megabatches', len(shuffled_megabatches)) for i_megabatch, megabatch in enumerate(shuffled_megabatches): re_megabatch = [] for i_batch, batch in enumerate(megabatch): assert len(batch) != 0 - len_each_batch = [lengths[i] for i in batch] - idx_length_dict = dict([*zip(batch, len_each_batch)]) - count_dict = Counter(len_each_batch) + + len_each_batch = [lengths[i] for i in batch] # ['1x256x256', '1x256x256'] + idx_length_dict = dict([*zip(batch, len_each_batch)]) # {0: '1x256x256', 100: '1x256x256'} + count_dict = Counter(len_each_batch) # {'1x256x256': 2} or {'1x256x256': 1, '1x768x256': 1} if len(count_dict) != 1: - sorted_by_value = sorted(count_dict.items(), key=lambda item: item[1]) + sorted_by_value = sorted( + count_dict.items(), key=lambda item: item[1] + ) # {'1x256x256': 1, '1x768x256': 1} + # import ipdb;ipdb.set_trace() # print(batch, idx_length_dict, count_dict, sorted_by_value) pick_length = sorted_by_value[-1][0] # the highest frequency candidate_batch = [idx for idx, length in idx_length_dict.items() if length == pick_length] @@ -332,6 +339,12 @@ def last_group_frame_fun(shuffled_megabatches, lengths): # print(batch, idx_length_dict, count_dict, sorted_by_value, pick_length, candidate_batch, random_select_batch) batch = candidate_batch + random_select_batch # print(batch) + + for i in range(1, len(batch) - 1): + # if not lengths[batch[0]] == lengths[batch[i]]: + # print(batch, [lengths[i] for i in batch]) + # import ipdb;ipdb.set_trace() + assert lengths[batch[0]] == lengths[batch[i]] re_megabatch.append(batch) re_shuffled_megabatches.append(re_megabatch) @@ -343,48 +356,159 @@ def last_group_frame_fun(shuffled_megabatches, lengths): return re_shuffled_megabatches -def last_group_resolution_fun(indices): - raise NotImplementedError - return indices - +def split_to_even_chunks(megabatch, lengths, world_size, batch_size): + """ + Split a list of indices into `chunks` chunks of roughly equal lengths. 
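+
+    One chunk is produced per rank (`world_size` chunks in total). A chunk shorter than
+    `batch_size` is padded by re-sampling from itself or, when it is empty, by copying an
+    already padded chunk.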
+ """ + # batch_size=2, world_size=2 + # [1, 2, 3, 4] -> [[1, 2], [3, 4]] + # [1, 2, 3] -> [[1, 2], [3]] + # [1, 2] -> [[1], [2]] + # [1] -> [[1], []] + chunks = [megabatch[i::world_size] for i in range(world_size)] -def last_group_frame_and_resolution_fun(indices): - raise NotImplementedError - return indices + pad_chunks = [] + for idx, chunk in enumerate(chunks): + if batch_size != len(chunk): + assert batch_size > len(chunk) + if len(chunk) != 0: # [[1, 2], [3]] -> [[1, 2], [3, 3]] + chunk = chunk + [random.choice(chunk) for _ in range(batch_size - len(chunk))] + else: + chunk = random.choice(pad_chunks) # [[1], []] -> [[1], [1]] + print(chunks[idx], "->", chunk) + pad_chunks.append(chunk) + return pad_chunks def get_length_grouped_indices( - lengths, batch_size, world_size, generator=None, group_frame=False, group_resolution=False, seed=42 + lengths, + batch_size, + world_size, + gradient_accumulation_size, + initial_global_step, + generator=None, + group_data=False, + seed=42, ): - # We need to use numpy for the random part as a distributed sampler will set the random seed if generator is None: generator = np.random.default_rng(seed) # every rank will generate a fixed order but random index + # print('lengths', lengths) + + if group_data: + indices = group_data_fun(lengths, generator) + else: + indices = generator.permutation(len(lengths)).tolist() + # print('indices', len(indices)) + + # print('sort indices', len(indices)) + # print('sort indices', indices) + # print('sort lengths', [lengths[i] for i in indices]) - indices = generator.permutation(len(lengths)).tolist() - if group_frame and not group_resolution: - indices = group_frame_fun(indices, lengths) - elif not group_frame and group_resolution: - indices = group_resolution_fun(indices) - elif group_frame and group_resolution: - indices = group_frame_and_resolution_fun(indices) megabatch_size = world_size * batch_size megabatches = [indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)] + # import ipdb;ipdb.set_trace() + # print('megabatches', len(megabatches)) + # print('\nmegabatches', megabatches) + # megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] + # import ipdb;ipdb.set_trace() + # print('sort megabatches', len(megabatches)) + # megabatches_len = [[lengths[i] for i in megabatch] for megabatch in megabatches] + # print(f'\nrank {accelerator.process_index} sorted megabatches_len', megabatches_len[0], megabatches_len[1], megabatches_len[-2], megabatches_len[-1]) + # import ipdb;ipdb.set_trace() + megabatches = [split_to_even_chunks(megabatch, lengths, world_size, batch_size) for megabatch in megabatches] + # import ipdb;ipdb.set_trace() + # print('nsplit_to_even_chunks megabatches', len(megabatches)) + # print('\nsplit_to_even_chunks megabatches', megabatches) + # split_to_even_chunks_len = [[[lengths[i] for i in batch] for batch in megabatch] for megabatch in megabatches] + # print(f'\nrank {accelerator.process_index} split_to_even_chunks_len', split_to_even_chunks_len[0], + # split_to_even_chunks_len[1], split_to_even_chunks_len[-2], split_to_even_chunks_len[-1]) + # print('\nsplit_to_even_chunks len', split_to_even_chunks_len) + # return [i for megabatch in megabatches for batch in megabatch for i in batch] + + indices_mega = generator.permutation(len(megabatches)).tolist() + # print(f'rank {accelerator.process_index} seed {seed}, len(megabatches) {len(megabatches)}, indices_mega, {indices_mega[:50]}') + shuffled_megabatches = [megabatches[i] 
for i in indices_mega] + # shuffled_megabatches_len = [ + # [[lengths[i] for i in batch] for batch in megabatch] for megabatch in shuffled_megabatches + # ] + # print(f'\nrank {accelerator.process_index} sorted shuffled_megabatches_len', shuffled_megabatches_len[0], + # shuffled_megabatches_len[1], shuffled_megabatches_len[-2], shuffled_megabatches_len[-1]) + + # import ipdb;ipdb.set_trace() + # print('shuffled_megabatches', len(shuffled_megabatches)) + if group_data: + shuffled_megabatches = last_group_data_fun(shuffled_megabatches, lengths) + # group_shuffled_megabatches_len = [ + # [[lengths[i] for i in batch] for batch in megabatch] for megabatch in shuffled_megabatches + # ] + # print(f'\nrank {accelerator.process_index} group_shuffled_megabatches_len', group_shuffled_megabatches_len[0], + # group_shuffled_megabatches_len[1], group_shuffled_megabatches_len[-2], group_shuffled_megabatches_len[-1]) + + # import ipdb;ipdb.set_trace() + initial_global_step = initial_global_step * gradient_accumulation_size + # print('shuffled_megabatches', len(shuffled_megabatches)) + # print('have been trained idx:', len(shuffled_megabatches[:initial_global_step])) + # print('shuffled_megabatches[:10]', shuffled_megabatches[:10]) + # print('have been trained idx:', shuffled_megabatches[:initial_global_step]) + shuffled_megabatches = shuffled_megabatches[initial_global_step:] + print(f"Skip the data of {initial_global_step} step!") + # print('after shuffled_megabatches', len(shuffled_megabatches)) + # print('after shuffled_megabatches[:10]', shuffled_megabatches[:10]) + + # print('\nshuffled_megabatches', shuffled_megabatches) + # import ipdb;ipdb.set_trace() + # print('\nshuffled_megabatches len', [[i, lengths[i]] for megabatch in shuffled_megabatches for batch in megabatch for i in batch]) + # return [i for megabatch in shuffled_megabatches for batch in megabatch for i in batch] # return epoch indices in a list + return [batch for megabatch in shuffled_megabatches for batch in megabatch] # return batch indices (list of lists) + + +class LengthGroupedSampler: + r""" + Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while + keeping a bit of randomness. 
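+
+    `__iter__` yields whole batches (lists of dataset indices) produced by
+    `get_length_grouped_indices`; the first `initial_global_step * gradient_accumulation_size`
+    megabatches are dropped so that a resumed run does not revisit already-trained data.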
+ """ - megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] + def __init__( + self, + batch_size: int, + world_size: int, + gradient_accumulation_size: int, + initial_global_step: int, + lengths: Optional[List[int]] = None, + group_data=False, + generator=None, + ): + if lengths is None: + raise ValueError("Lengths must be provided.") - megabatches = [split_to_even_chunks(megabatch, lengths, world_size, batch_size) for megabatch in megabatches] + self.batch_size = batch_size + self.world_size = world_size + self.initial_global_step = initial_global_step + self.gradient_accumulation_size = gradient_accumulation_size + self.lengths = lengths + self.group_data = group_data + self.generator = generator + # print('self.lengths, self.initial_global_step, self.batch_size, self.world_size, self.gradient_accumulation_size', + # len(self.lengths), self.initial_global_step, self.batch_size, self.world_size, self.gradient_accumulation_size) + + def __len__(self): + return ( + len(self.lengths) + - self.initial_global_step * self.batch_size * self.world_size * self.gradient_accumulation_size + ) - indices = generator.permutation(len(megabatches)).tolist() - shuffled_megabatches = [megabatches[i] for i in indices] - if group_frame and not group_resolution: - shuffled_megabatches = last_group_frame_fun(shuffled_megabatches, lengths) - elif not group_frame and group_resolution: - shuffled_megabatches = last_group_resolution_fun(shuffled_megabatches, indices) - elif group_frame and group_resolution: - shuffled_megabatches = last_group_frame_and_resolution_fun(shuffled_megabatches, indices) + def __iter__(self): + indices = get_length_grouped_indices( + self.lengths, + self.batch_size, + self.world_size, + self.gradient_accumulation_size, + self.initial_global_step, + group_data=self.group_data, + generator=self.generator, + ) - # return [i for megabatch in shuffled_megabatches for batch in megabatch for i in batch] - return [batch for megabatch in shuffled_megabatches for batch in megabatch] # return batch indices + return iter(indices) class LengthGroupedBatchSampler: @@ -401,7 +525,7 @@ def __init__( initial_global_step_for_sampler: int = 0, group_frame=False, group_resolution=False, - group_data = False, + group_data=False, generator=None, ): if lengths is None: diff --git a/examples/opensora_pku/scripts/train_data/video_data_v1_2.txt b/examples/opensora_pku/scripts/train_data/video_data_v1_2.txt index af4c65966e..2b47609591 100644 --- a/examples/opensora_pku/scripts/train_data/video_data_v1_2.txt +++ b/examples/opensora_pku/scripts/train_data/video_data_v1_2.txt @@ -1 +1 @@ -datasets/Open-Sora-Plan-v1.2.0,datasets/Open-Sora-Plan-v1.2.0/mixkit_emb-len=512,datasets/Open-Sora-Plan-v1.2.0/v1.1.0_HQ_part1_Traffic_train.json +datasets/Open-Sora-Plan-v1.3.0/videos_16/,datasets/Open-Sora-Plan-v1.3.0/videos_16_emb-len=512,datasets/Open-Sora-Plan-v1.3.0/opendv_vid16.json diff --git a/examples/opensora_pku/tests/test_data.py b/examples/opensora_pku/tests/test_data.py index 601bbf8de8..b7efd0c9d2 100644 --- a/examples/opensora_pku/tests/test_data.py +++ b/examples/opensora_pku/tests/test_data.py @@ -13,7 +13,7 @@ from opensora.models.causalvideovae import ae_stride_config from opensora.models.diffusion import Diffusion_models from opensora.train.commons import parse_args -from opensora.utils.dataset_utils import Collate, LengthGroupedBatchSampler +from opensora.utils.dataset_utils import Collate, LengthGroupedSampler from opensora.utils.message_utils import 
print_banner from mindone.utils.config import str2bool @@ -47,31 +47,26 @@ def load_dataset_and_dataloader(args, device_num=1, rank_id=0): assert args.dataset == "t2v", "Support t2v dataset only." print_banner("Dataset Loading") # Setup data: + if args.trained_data_global_step is not None: + initial_global_step_for_sampler = args.trained_data_global_step + else: + initial_global_step_for_sampler = 0 + total_batch_size = args.train_batch_size * device_num * args.gradient_accumulation_steps + total_batch_size = total_batch_size // args.sp_size * args.train_sp_batch_size + args.total_batch_size = total_batch_size + if args.max_hxw is not None and args.min_hxw is None: + args.min_hxw = args.max_hxw // 4 + train_dataset = getdataset(args, dataset_file=args.data) - sampler = ( - LengthGroupedBatchSampler( - args.train_batch_size, - world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), - lengths=train_dataset.lengths, - group_frame=args.group_frame, - group_resolution=args.group_resolution, - ) - if (args.group_frame or args.group_resolution) - else None - ) - collate_fn = Collate( + sampler = LengthGroupedSampler( args.train_batch_size, - args.group_frame, - args.group_resolution, - args.max_height, - args.max_width, - args.ae_stride, - args.ae_stride_t, - args.patch_size, - args.patch_size_t, - args.num_frames, - args.use_image_num, + world_size=device_num if not get_sequence_parallel_state() else (device_num // hccl_info.world_size), + gradient_accumulation_size=args.gradient_accumulation_steps, + initial_global_step=initial_global_step_for_sampler, + lengths=train_dataset.lengths, + group_data=args.group_data, ) + collate_fn = Collate(args) dataloader = create_dataloader( train_dataset, batch_size=args.train_batch_size, @@ -101,7 +96,12 @@ def parse_t2v_train_args(parser): parser.add_argument("--hw_stride", type=int, default=32) parser.add_argument("--force_resolution", action="store_true") parser.add_argument("--trained_data_global_step", type=int, default=None) - parser.add_argument("--use_decord", action="store_true") + parser.add_argument( + "--use_decord", + type=str2bool, + default=True, + help="whether to use decord to load videos. If not, use opencv to load videos.", + ) # text encoder & vae & diffusion model parser.add_argument("--vae_fp32", action="store_true") diff --git a/examples/opensora_pku/tests/test_data.sh b/examples/opensora_pku/tests/test_data.sh index 72feb24a23..27e2d6aa8d 100644 --- a/examples/opensora_pku/tests/test_data.sh +++ b/examples/opensora_pku/tests/test_data.sh @@ -3,19 +3,20 @@ python tests/test_data.py \ --text_encoder_name_1 google/mt5-xxl \ --dataset t2v \ --num_frames 93 \ - --data "scripts/train_data/merge_data_mixkit.txt" \ + --data "scripts/train_data/video_data_v1_2.txt" \ --cache_dir "./" \ --ae WFVAEModel_D8_4x8x8 \ --ae_path "LanguageBind/Open-Sora-Plan-v1.3.0/vae" \ --sample_rate 1 \ - --max_height 352 \ + --max_height 640 \ --max_width 640 \ + --max_hxw 409600 \ --train_fps 16 \ - --force_resolution \ --interpolation_scale_t 1.0 \ --interpolation_scale_h 1.0 \ --interpolation_scale_w 1.0 \ - --train_batch_size=8 \ - --dataloader_num_workers 20 \ + --train_batch_size=1 \ + --dataloader_num_workers 8 \ --output_dir="test_data/" \ --model_max_length 512 \ + # --force_resolution \
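As an illustration of the new length-grouped batching (not part of the patch itself), the sketch below calls get_length_grouped_indices on a handful of made-up shape keys. It assumes the examples/opensora_pku package is importable so that opensora.utils.dataset_utils resolves; the toy lengths and batch settings are invented for the example. With group_data=True every returned batch holds a single shape key, which is what allows Collate to expect one shape per batch instead of padding mixed resolutions.

from opensora.utils.dataset_utils import get_length_grouped_indices

# Toy shape keys in the 'TxHxW' style used by the dataset lengths (e.g. '1x256x256'); values are made up.
lengths = ["93x640x640"] * 5 + ["29x320x320"] * 3

batches = get_length_grouped_indices(
    lengths,
    batch_size=2,
    world_size=1,
    gradient_accumulation_size=1,
    initial_global_step=0,  # a resumed run would skip initial_global_step * gradient_accumulation_size megabatches
    group_data=True,
)
for batch in batches:
    print(batch, [lengths[i] for i in batch])  # each batch resolves to a single shape key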