
Commit

update new dataset v0.3.0
wtomin committed Dec 23, 2024
1 parent bfd56d3 commit de77902
Showing 9 changed files with 1,103 additions and 509 deletions.
46 changes: 13 additions & 33 deletions examples/opensora_pku/opensora/dataset/__init__.py
@@ -6,7 +6,7 @@
 from transformers import AutoTokenizer
 
 from .t2v_datasets import T2V_dataset
-from .transform import TemporalRandomCrop, center_crop_th_tw, spatial_stride_crop_video, maxhxw_resize
+from .transform import TemporalRandomCrop, center_crop_th_tw, maxhxw_resize, spatial_stride_crop_video
 
 
 def getdataset(args, dataset_file):
@@ -18,16 +18,11 @@ def norm_func_albumentation(image, **kwargs):
 
     mapping = {"bilinear": cv2.INTER_LINEAR, "bicubic": cv2.INTER_CUBIC}
     targets = {"image{}".format(i): "image" for i in range(args.num_frames)}
-    resize_topcrop = [
-        Lambda(
-            name="crop_topcrop",
-            image=partial(center_crop_th_tw, th=args.max_height, tw=args.max_width, top_crop=True),
-            p=1.0,
-        ),
-        Resize(args.max_height, args.max_width, interpolation=mapping["bilinear"]),
-    ]
 
     if args.force_resolution:
-        assert (args.max_height is not None) and (args.max_width is not None), "set max_height and max_width for fixed resolution"
+        assert (args.max_height is not None) and (
+            args.max_width is not None
+        ), "set max_height and max_width for fixed resolution"
         resize = [
             Lambda(
                 name="crop_centercrop",
@@ -36,7 +31,7 @@ def norm_func_albumentation(image, **kwargs):
             ),
             Resize(args.max_height, args.max_width, interpolation=mapping["bilinear"]),
         ]
-    else: # dynamic resolution
+    else:  # dynamic resolution
         assert args.max_hxw is not None, "set max_hxw for dynamic resolution"
         resize = [
             Lambda(
@@ -46,7 +41,7 @@ def norm_func_albumentation(image, **kwargs):
             ),
             Lambda(
                 name="spatial_stride_crop",
-                image=partial(spatial_stride_crop_video, stride=args.hw_stride), # default stride=32
+                image=partial(spatial_stride_crop_video, stride=args.hw_stride),  # default stride=32
                 p=1.0,
             ),
         ]
@@ -55,35 +50,20 @@ def norm_func_albumentation(image, **kwargs):
         [*resize, ToFloat(255.0), Lambda(name="ae_norm", image=norm_func_albumentation, p=1.0)],
         additional_targets=targets,
     )
-    transform_topcrop = Compose(
-        [*resize_topcrop, ToFloat(255.0), Lambda(name="ae_norm", image=norm_func_albumentation, p=1.0)],
-        additional_targets=targets,
-    )
 
-    tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name_1, cache_dir=args.cache_dir)
+    tokenizer_1 = AutoTokenizer.from_pretrained(args.text_encoder_name_1, cache_dir=args.cache_dir)
+    tokenizer_2 = None
+    if args.text_encoder_name_2 is not None:
+        tokenizer_2 = AutoTokenizer.from_pretrained(args.text_encoder_name_2, cache_dir=args.cache_dir)
 
     if args.dataset == "t2v":
         return T2V_dataset(
-            dataset_file,
-            num_frames=args.num_frames,
-            train_fps=args.train_fps,
-            use_image_num=args.use_image_num,
-            use_img_from_vid=args.use_img_from_vid,
-            model_max_length=args.model_max_length,
-            cfg=args.cfg,
-            speed_factor=args.speed_factor,
-            max_height=args.max_height,
-            max_width=args.max_width,
-            drop_short_ratio=args.drop_short_ratio,
-            dataloader_num_workers=args.dataloader_num_workers,
-            text_encoder_name=args.text_encoder_name_1,  # TODO: update with 2nd text encoder
-            return_text_emb=args.text_embed_cache,
+            args,
             transform=transform,
            temporal_sample=temporal_sample,
-            tokenizer=tokenizer,
-            transform_topcrop=transform_topcrop,
+            tokenizer_1=tokenizer_1,
+            tokenizer_2=tokenizer_2,
+            return_text_emb=args.text_embed_cache,
         )
     elif args.dataset == "inpaint" or args.dataset == "i2v":
         raise NotImplementedError
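For reference, the new dual-tokenizer wiring can be exercised on its own. A minimal sketch in plain Python — the model ids below are illustrative placeholders, not necessarily the text encoders opensora_pku ships with:

# Standalone sketch of the dual-tokenizer logic introduced in getdataset above.
# Model ids are placeholders standing in for args.text_encoder_name_1/_2.
from transformers import AutoTokenizer

text_encoder_name_1 = "google/mt5-xl"  # assumed stand-in for args.text_encoder_name_1
text_encoder_name_2 = None             # e.g. "openai/clip-vit-large-patch14" to enable a 2nd encoder

tokenizer_1 = AutoTokenizer.from_pretrained(text_encoder_name_1)
tokenizer_2 = None
if text_encoder_name_2 is not None:
    tokenizer_2 = AutoTokenizer.from_pretrained(text_encoder_name_2)

# Both are then passed to T2V_dataset(args, ..., tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2);
# with return_text_emb=True the dataset presumably loads cached text embeddings instead of tokenizing.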
36 changes: 30 additions & 6 deletions examples/opensora_pku/opensora/dataset/loader.py
@@ -26,6 +26,7 @@ def create_dataloader(
     enable_modelarts=False,
     collate_fn=None,
     sampler=None,
+    batch_sampler=None,
 ):
     datalen = len(dataset)
 
@@ -46,6 +47,7 @@
         shuffle=shuffle,
         drop_last=drop_last,
         sampler=sampler,
+        batch_sampler=batch_sampler,
     )
     dl = GeneratorDataset(
         loader,
@@ -62,13 +64,24 @@
 
 
 def build_dataloader(
-    dataset, datalens, collate_fn, batch_size, device_num, rank_id=0, sampler=None, shuffle=True, drop_last=True
+    dataset,
+    datalens,
+    collate_fn,
+    batch_size,
+    device_num,
+    rank_id=0,
+    sampler=None,
+    batch_sampler=None,
+    shuffle=True,
+    drop_last=True,
 ):
-    if sampler is None:
-        sampler = BatchSampler(datalens, batch_size=batch_size, device_num=device_num, shuffle=shuffle)
+    if batch_sampler is None:
+        batch_sampler = BatchSampler(datalens, batch_size=batch_size, device_num=device_num, shuffle=shuffle)
     loader = DataLoader(
         dataset,
-        batch_sampler=sampler,
+        batch_size=batch_size,
+        sampler=sampler,
+        batch_sampler=batch_sampler,
         collate_fn=collate_fn,
         device_num=device_num,
         drop_last=drop_last,
@@ -107,14 +120,25 @@ def __len__(self):
 class DataLoader:
     """DataLoader"""
 
-    def __init__(self, dataset, batch_sampler, collate_fn, device_num=1, drop_last=True, rank_id=0):
+    def __init__(
+        self,
+        dataset,
+        batch_size,
+        sampler=None,
+        batch_sampler=None,
+        collate_fn=None,
+        device_num=1,
+        drop_last=True,
+        rank_id=0,
+    ):
         self.dataset = dataset
+        self.sampler = sampler
         self.batch_sampler = batch_sampler
         self.collat_fn = collate_fn
         self.device_num = device_num
         self.rank_id = rank_id
         self.drop_last = drop_last
-        self.batch_size = len(next(iter(self.batch_sampler)))
+        self.batch_size = batch_size
 
     def __iter__(self):
         self.step_index = 0
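The loader refactor separates a per-sample sampler from a batch_sampler that yields whole index batches, and DataLoader now takes batch_size explicitly instead of inferring it from the first batch — presumably to support batch samplers whose batches vary in size (e.g. resolution bucketing). A toy sketch of the sampler vs. batch_sampler distinction, in plain Python rather than the opensora_pku classes:

# Toy illustration of the sampler / batch_sampler split introduced above;
# this mirrors the semantics only, not the opensora_pku implementation.
import random

data = [f"clip_{i}" for i in range(10)]

# A sampler yields individual indices; the loader groups them into
# fixed-size batches itself using batch_size.
indices = list(range(len(data)))
random.shuffle(indices)
batch_size = 4
batches_from_sampler = [indices[i : i + batch_size] for i in range(0, len(indices), batch_size)]

# A batch_sampler yields ready-made index batches, so it fully controls
# batch composition (e.g. bucketing clips of similar resolution or length),
# and batch sizes may vary from batch to batch.
batch_sampler = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

print([[data[i] for i in batch] for batch in batches_from_sampler])
print([[data[i] for i in batch] for batch in batch_sampler])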
[Diffs for the remaining 7 changed files not shown]
