From 8b85f91355f46b82eeb6a38acf4c146f7fc38ce9 Mon Sep 17 00:00:00 2001 From: fatih <34196005+fcakyon@users.noreply.github.com> Date: Sun, 27 Nov 2022 23:18:22 +0300 Subject: [PATCH] refactor video data loaders, fix some bugs (#22) * ignore mp4 avi zip files * update dependency versions * ignore onnx files * increase package version * fix a typo * refacator dataset loading * update code snippets in readme * reformat with isort * update workflows * ignore export and examples folders * clean code --- .github/workflows/ci.yml | 3 + .github/workflows/package_testing.yml | 3 + .gitignore | 8 +- README.md | 32 +- requirements.txt | 11 +- video_transformers/__init__.py | 2 +- video_transformers/data.py | 53 ++- .../pytorchvideo_wrapper/__init__.py | 0 .../pytorchvideo_wrapper/data/__init__.py | 0 .../data/labeled_video_dataset.py | 331 ++++++++++++++++++ .../data/labeled_video_paths.py | 73 ---- video_transformers/utils/extra.py | 4 +- video_transformers/utils/file.py | 13 + 13 files changed, 405 insertions(+), 128 deletions(-) create mode 100644 video_transformers/pytorchvideo_wrapper/__init__.py create mode 100644 video_transformers/pytorchvideo_wrapper/data/__init__.py create mode 100644 video_transformers/pytorchvideo_wrapper/data/labeled_video_dataset.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8ede845..fee83bd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -64,6 +64,9 @@ jobs: if: matrix.operating-system == 'macos-latest' run: pip install torch==${{ matrix.torch-version }} + - name: Install Pytorchvideo from main branch + run: pip install git+https://github.com/facebookresearch/pytorchvideo.git + - name: Lint with flake8, black and isort run: | pip install .[dev] diff --git a/.github/workflows/package_testing.yml b/.github/workflows/package_testing.yml index 83ef596..388ce2c 100644 --- a/.github/workflows/package_testing.yml +++ b/.github/workflows/package_testing.yml @@ -63,6 +63,9 @@ jobs: if: matrix.operating-system == 'macos-latest' run: pip install torch==${{ matrix.torch-version }} + - name: Install Pytorchvideo from main branch + run: pip install git+https://github.com/facebookresearch/pytorchvideo.git + - name: Install latest video-transformers package run: > pip install --upgrade --force-reinstall video-transformers[test] diff --git a/.gitignore b/.gitignore index 09924eb..ed81d35 100644 --- a/.gitignore +++ b/.gitignore @@ -131,4 +131,10 @@ dmypy.json # extra .vscode .neptune -runs/ \ No newline at end of file +runs/ +*.mp4 +*.avi +*.zip +*.onnx +exports/ +examples/ \ No newline at end of file diff --git a/README.md b/README.md index 707cab8..0564603 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,12 @@ and supports: conda install pytorch=1.11.0 torchvision=0.12.0 cudatoolkit=11.3 -c pytorch ``` +- Install pytorchvideo from main branch: + +```bash +pip install git+https://github.com/facebookresearch/pytorchvideo.git +``` + - Install `video-transformers`: ```bash @@ -87,6 +93,7 @@ from video_transformers.data import VideoDataModule from video_transformers.heads import LinearHead from video_transformers.necks import TransformerNeck from video_transformers.trainer import trainer_factory +from video_transformers.utils.file import download_ucf6 backbone = TimeDistributed(TransformersBackbone("microsoft/cvt-13", num_unfrozen_stages=0)) neck = TransformerNeck( @@ -96,13 +103,11 @@ neck = TransformerNeck( transformer_enc_num_layers=2, dropout_p=0.1, ) -optimizer = AdamW(model.parameters(), lr=1e-4) +download_ucf6("./") datamodule = 
VideoDataModule( - train_root=".../ucf6/train", - val_root=".../ucf6/val", - clip_duration=2, - train_dataset_multiplier=1, + train_root="ucf6/train", + val_root="ucf6/val", batch_size=4, num_workers=4, num_timesteps=8, @@ -110,14 +115,16 @@ datamodule = VideoDataModule( preprocess_clip_duration=1, preprocess_means=backbone.mean, preprocess_stds=backbone.std, - preprocess_min_short_side_scale=256, - preprocess_max_short_side_scale=320, + preprocess_min_short_side=256, + preprocess_max_short_side=320, preprocess_horizontal_flip_p=0.5, ) head = LinearHead(hidden_size=neck.num_features, num_classes=datamodule.num_classes) model = VideoModel(backbone, head, neck) +optimizer = AdamW(model.parameters(), lr=1e-4) + Trainer = trainer_factory("single_label_classification") trainer = Trainer( datamodule, @@ -139,14 +146,15 @@ from video_transformers.data import VideoDataModule from video_transformers.heads import LinearHead from video_transformers.necks import GRUNeck from video_transformers.trainer import trainer_factory +from video_transformers.utils.file import download_ucf6 backbone = TimeDistributed(TimmBackbone("mobilevitv2_100", num_unfrozen_stages=0)) neck = GRUNeck(num_features=backbone.num_features, hidden_size=128, num_layers=2, return_last=True) +download_ucf6("./") datamodule = VideoDataModule( - train_root=".../ucf6/train", - val_root=".../ucf6/val", - train_dataset_multiplier=1, + train_root="ucf6/train", + val_root="ucf6/val", batch_size=4, num_workers=4, num_timesteps=8, @@ -154,8 +162,8 @@ datamodule = VideoDataModule( preprocess_clip_duration=1, preprocess_means=backbone.mean, preprocess_stds=backbone.std, - preprocess_min_short_side_scale=256, - preprocess_max_short_side_scale=320, + preprocess_min_short_side=256, + preprocess_max_short_side=320, preprocess_horizontal_flip_p=0.5, ) diff --git a/requirements.txt b/requirements.txt index 85eabb9..b8db1bc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,12 @@ -accelerate>=0.12.0 -evaluate>=0.2.2 -transformers>=4.23.1 -timm>=0.6.7 +accelerate>=0.14.0,<0.15.0 +evaluate>=0.3.0,<0.4.0 +transformers>=4.24.0,<4.25.0 +timm>=0.6.12,<0.7.0 click==8.0.4 -pytorchvideo balanced-loss scikit-learn tensorboard opencv-python gradio>=3.1.6 -huggingface-hub>=0.10.1 +huggingface-hub>=0.11.0,<0.12.0 importlib-metadata>=1.1.0,<4.3;python_version<'3.8' diff --git a/video_transformers/__init__.py b/video_transformers/__init__.py index 0228418..1429fd3 100644 --- a/video_transformers/__init__.py +++ b/video_transformers/__init__.py @@ -3,4 +3,4 @@ from video_transformers.auto.neck import AutoNeck from video_transformers.modeling import TimeDistributed, VideoModel -__version__ = "0.0.6" +__version__ = "0.0.7" diff --git a/video_transformers/data.py b/video_transformers/data.py index f6c03a5..c2019b4 100644 --- a/video_transformers/data.py +++ b/video_transformers/data.py @@ -14,6 +14,7 @@ from torch.utils.data import DataLoader from torchvision.transforms import CenterCrop, Compose, Lambda, RandomCrop, RandomHorizontalFlip +from video_transformers.pytorchvideo_wrapper.data.labeled_video_dataset import labeled_video_dataset from video_transformers.pytorchvideo_wrapper.data.labeled_video_paths import LabeledVideoDataset, LabeledVideoPaths from video_transformers.utils.extra import class_to_config @@ -53,8 +54,8 @@ def __init__( input_size: model input isze means: mean of the video clip stds: standard deviation of the video clip - min_short_side_scale: minimum short side of the video clip - max_short_side_scale: maximum short side of the video 
clip + min_short_side: minimum short side of the video clip + max_short_side: maximum short side of the video clip horizontal_flip_p: probability of horizontal flip clip_duration: duration of each video clip @@ -77,10 +78,13 @@ def __init__( self.clip_duration = clip_duration # Transforms applied to train dataset. + def normalize_func(x): + return x / 255.0 + self.train_video_transform = Compose( [ UniformTemporalSubsample(self.num_timesteps), - Lambda(lambda x: x / 255.0), + Lambda(normalize_func), Normalize(self.means, self.stds), RandomShortSideScale( min_size=self.min_short_side, @@ -97,7 +101,7 @@ def __init__( self.val_video_transform = Compose( [ UniformTemporalSubsample(self.num_timesteps), - Lambda(lambda x: x / 255.0), + Lambda(normalize_func), Normalize(self.means, self.stds), ShortSideScale(self.min_short_side), CenterCrop(self.input_size), @@ -112,7 +116,6 @@ def __init__( train_root: str, val_root: str, test_root: str = None, - train_dataset_multiplier: int = 1, batch_size: int = 4, num_workers: int = 4, num_timesteps: int = 8, @@ -158,8 +161,6 @@ def __init__( Path to kinetics formatted train folder. clip_duration: float Duration of sampled clip for each video. - train_dataset_multiplier: int - Multipler for number of of random training data samples. batch_size: int Batch size for training and validation. num_workers: int @@ -196,7 +197,6 @@ def __init__( self.train_root = train_root self.val_root = val_root self.test_root = test_root if test_root is not None else val_root - self.train_dataset_multiplier = train_dataset_multiplier self.labels = None self.train_dataloader = self._get_train_dataloader() @@ -212,18 +212,13 @@ def config(self) -> Dict: return class_to_config(self, ignored_attrs=("config", "train_root", "val_root", "test_root")) def _get_train_dataloader(self): - labeled_video_paths = LabeledVideoPaths.from_path(self.train_root) - labeled_video_paths.path_prefix = "" - video_sampler = torch.utils.data.RandomSampler clip_sampler = pytorchvideo.data.make_clip_sampler("random", self.preprocessor_config["clip_duration"]) - dataset = LabeledVideoDataset( - labeled_video_paths, - clip_sampler, - video_sampler, - self.preprocessor.train_transform, + dataset = labeled_video_dataset( + data_path=self.train_root, + clip_sampler=clip_sampler, + transform=self.preprocessor.train_transform, decode_audio=False, decoder="pyav", - dataset_multiplier=self.train_dataset_multiplier, ) self.labels = dataset.labels return DataLoader( @@ -234,18 +229,14 @@ def _get_train_dataloader(self): ) def _get_val_dataloader(self): - labeled_video_paths = LabeledVideoPaths.from_path(self.val_root) - labeled_video_paths.path_prefix = "" - video_sampler = torch.utils.data.SequentialSampler clip_sampler = pytorchvideo.data.clip_sampling.UniformClipSamplerTruncateFromStart( clip_duration=self.preprocessor_config["clip_duration"], truncation_duration=self.preprocessor_config["clip_duration"], ) - dataset = LabeledVideoDataset( - labeled_video_paths, - clip_sampler, - video_sampler, - self.preprocessor.val_transform, + dataset = labeled_video_dataset( + data_path=self.val_root, + clip_sampler=clip_sampler, + transform=self.preprocessor.val_transform, decode_audio=False, decoder="pyav", ) @@ -257,18 +248,14 @@ def _get_val_dataloader(self): ) def _get_test_dataloader(self): - labeled_video_paths = LabeledVideoPaths.from_path(self.test_root) - labeled_video_paths.path_prefix = "" - video_sampler = torch.utils.data.SequentialSampler clip_sampler = 
pytorchvideo.data.clip_sampling.UniformClipSamplerTruncateFromStart( clip_duration=self.preprocessor_config["clip_duration"], truncation_duration=self.preprocessor_config["clip_duration"], ) - dataset = LabeledVideoDataset( - labeled_video_paths, - clip_sampler, - video_sampler, - self.preprocessor.val_transform, + dataset = labeled_video_dataset( + data_path=self.test_root, + clip_sampler=clip_sampler, + transform=self.preprocessor.val_transform, decode_audio=False, decoder="pyav", ) diff --git a/video_transformers/pytorchvideo_wrapper/__init__.py b/video_transformers/pytorchvideo_wrapper/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/video_transformers/pytorchvideo_wrapper/data/__init__.py b/video_transformers/pytorchvideo_wrapper/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/video_transformers/pytorchvideo_wrapper/data/labeled_video_dataset.py b/video_transformers/pytorchvideo_wrapper/data/labeled_video_dataset.py new file mode 100644 index 0000000..791e32d --- /dev/null +++ b/video_transformers/pytorchvideo_wrapper/data/labeled_video_dataset.py @@ -0,0 +1,331 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +# Modified from https://github.com/facebookresearch/pytorchvideo/blob/9180d6d57cb9e15100ec9df3a049cf8d1121b302/pytorchvideo/data/labeled_video_dataset.py + +from __future__ import annotations + +import gc +import logging +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, Tuple, Type + +import torch.utils.data +from pytorchvideo.data.clip_sampling import ( + ClipSampler, + ConstantClipsPerVideoSampler, + RandomClipSampler, + UniformClipSampler, +) + +try: # this class is not available in the pypi version of pytorchvideo package + from pytorchvideo.data.clip_sampling import UniformClipSamplerTruncateFromStart +except ImportError: + UniformClipSamplerTruncateFromStart = None + +from pytorchvideo.data.utils import MultiProcessSampler +from pytorchvideo.data.video import VideoPathHandler + +from video_transformers.pytorchvideo_wrapper.data.labeled_video_paths import LabeledVideoPaths + +logger = logging.getLogger(__name__) + + +class LabeledVideoDataset(torch.utils.data.IterableDataset): + """ + LabeledVideoDataset handles the storage, loading, decoding and clip sampling for a + video dataset. It assumes each video is stored as either an encoded video + (e.g. mp4, avi) or a frame video (e.g. a folder of jpg, or png) + """ + + _MAX_CONSECUTIVE_FAILURES = 10 + + def __init__( + self, + labeled_video_paths: List[Tuple[str, Optional[dict]]], + clip_sampler: ClipSampler, + video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler, + transform: Optional[Callable[[dict], Any]] = None, + decode_audio: bool = True, + decode_video: bool = True, + decoder: str = "pyav", + ) -> None: + """ + Args: + labeled_video_paths (List[Tuple[str, Optional[dict]]]): List containing + video file paths and associated labels. If video paths are a folder + it's interpreted as a frame video, otherwise it must be an encoded + video. + + clip_sampler (ClipSampler): Defines how clips should be sampled from each + video. See the clip sampling documentation for more information. + + video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal + video container. This defines the order videos are decoded and, + if necessary, the distributed split. + + transform (Callable): This callable is evaluated on the clip output before + the clip is returned. 
It can be used for user defined preprocessing and + augmentations on the clips. The clip output format is described in __next__(). + + decode_audio (bool): If True, decode audio from video. + + decode_video (bool): If True, decode video frames from a video container. + + decoder (str): Defines what type of decoder used to decode a video. Not used for + frame videos. + """ + self._decode_audio = decode_audio + self._decode_video = decode_video + self._transform = transform + self._clip_sampler = clip_sampler + self._labeled_videos = labeled_video_paths + self._decoder = decoder + + # If a RandomSampler is used we need to pass in a custom random generator that + # ensures all PyTorch multiprocess workers have the same random seed. + self._video_random_generator = None + if video_sampler == torch.utils.data.RandomSampler: + self._video_random_generator = torch.Generator() + self._video_sampler = video_sampler(self._labeled_videos, generator=self._video_random_generator) + else: + self._video_sampler = video_sampler(self._labeled_videos) + + self._video_sampler_iter = None # Initialized on first call to self.__next__() + + # Depending on the clip sampler type, we may want to sample multiple clips + # from one video. In that case, we keep the store video, label and previous sampled + # clip time in these variables. + self._loaded_video_label = None + self._loaded_clip = None + self._last_clip_end_time = None + self.video_path_handler = VideoPathHandler() + + @property + def video_sampler(self): + """ + Returns: + The video sampler that defines video sample order. Note that you'll need to + use this property to set the epoch for a torch.utils.data.DistributedSampler. + """ + return self._video_sampler + + def __len__(self): + """ + Returns: + Number of videos in dataset. + """ + if isinstance( + self._clip_sampler, (RandomClipSampler, ConstantClipsPerVideoSampler, UniformClipSamplerTruncateFromStart) + ): + return len(self.video_sampler) + elif isinstance(self._clip_sampler, UniformClipSampler): + return None + + def __next__(self) -> dict: + """ + Retrieves the next clip based on the clip sampling strategy and video sampler. + + Returns: + A dictionary with the following format. + + .. code-block:: text + + { + 'video': , + 'label': , + 'video_label': + 'video_index': , + 'clip_index': , + 'aug_index': , + 'clip_start_index': , + } + """ + if not self._video_sampler_iter: + # Setup MultiProcessSampler here - after PyTorch DataLoader workers are spawned. + self._video_sampler_iter = iter(MultiProcessSampler(self._video_sampler)) + + for i_try in range(self._MAX_CONSECUTIVE_FAILURES): + # Reuse previously stored video if there are still clips to be sampled from + # the last loaded video. 
+ if self._loaded_video_label: + video, info_dict, video_index = self._loaded_video_label + else: + video_index = next(self._video_sampler_iter) + try: + video_path, info_dict = self._labeled_videos[video_index] + video = self.video_path_handler.video_from_path( + video_path, + decode_audio=self._decode_audio, + decode_video=self._decode_video, + decoder=self._decoder, + ) + self._loaded_video_label = (video, info_dict, video_index) + except Exception as e: + logger.debug( + "Failed to load video with error: {}; trial {}".format( + e, + i_try, + ) + ) + logger.exception("Video load exception") + continue + + ( + clip_start, + clip_end, + clip_index, + aug_index, + is_last_clip, + ) = self._clip_sampler(self._last_clip_end_time, video.duration, info_dict) + + if isinstance(clip_start, list): # multi-clip in each sample + + # Only load the clips once and reuse previously stored clips if there are multiple + # views for augmentations to perform on the same clips. + if aug_index[0] == 0: + self._loaded_clip = {} + loaded_clip_list = [] + for i in range(len(clip_start)): + clip_dict = video.get_clip(clip_start[i], clip_end[i]) + if clip_dict is None or clip_dict["video"] is None: + self._loaded_clip = None + break + loaded_clip_list.append(clip_dict) + + if self._loaded_clip is not None: + for key in loaded_clip_list[0].keys(): + self._loaded_clip[key] = [x[key] for x in loaded_clip_list] + + else: # single clip case + + # Only load the clip once and reuse previously stored clip if there are multiple + # views for augmentations to perform on the same clip. + if aug_index == 0: + self._loaded_clip = video.get_clip(clip_start, clip_end) + + self._last_clip_end_time = clip_end + + video_is_null = self._loaded_clip is None or self._loaded_clip["video"] is None + if (is_last_clip[-1] if isinstance(is_last_clip, list) else is_last_clip) or video_is_null: + # Close the loaded encoded video and reset the last sampled clip time ready + # to sample a new video on the next iteration. + self._loaded_video_label[0].close() + self._loaded_video_label = None + self._last_clip_end_time = None + self._clip_sampler.reset() + + # Force garbage collection to release video container immediately + # otherwise memory can spike. + gc.collect() + + if video_is_null: + logger.debug("Failed to load clip {}; trial {}".format(video.name, i_try)) + continue + + frames = self._loaded_clip["video"] + audio_samples = self._loaded_clip["audio"] + sample_dict = { + "video": frames, + "video_name": video.name, + "video_index": video_index, + "clip_index": clip_index, + "aug_index": aug_index, + **info_dict, + **({"audio": audio_samples} if audio_samples is not None else {}), + } + if self._transform is not None: + sample_dict = self._transform(sample_dict) + + # User can force dataset to continue by returning None in transform. + if sample_dict is None: + continue + + return sample_dict + else: + raise RuntimeError(f"Failed to load video after {self._MAX_CONSECUTIVE_FAILURES} retries.") + + def __iter__(self): + self._video_sampler_iter = None # Reset video sampler + + # If we're in a PyTorch DataLoader multiprocessing context, we need to use the + # same seed for each worker's RandomSampler generator. The workers at each + # __iter__ call are created from the unique value: worker_info.seed - worker_info.id, + # which we can use for this seed. 
+ worker_info = torch.utils.data.get_worker_info() + if self._video_random_generator is not None and worker_info is not None: + base_seed = worker_info.seed - worker_info.id + self._video_random_generator.manual_seed(base_seed) + + return self + + @property + def videos_per_class(self): + class_id_to_number = defaultdict(int) + for _labeled_video in self._labeled_videos: + label_info = _labeled_video[1] + class_id_to_number[label_info["label"]] += 1 + class_ids = list(class_id_to_number.keys()) + return [class_id_to_number[class_id] for class_id in range(max(class_ids) + 1)] + + @property + def labels(self): + """ + Returns: + The list of class labels. + """ + return self._labeled_videos.labels + + +def labeled_video_dataset( + data_path: str, + clip_sampler: ClipSampler, + video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler, + transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + video_path_prefix: str = "", + decode_audio: bool = True, + decoder: str = "pyav", +) -> LabeledVideoDataset: + """ + A helper function to create ``LabeledVideoDataset`` object for Ucf101 and Kinetics datasets. + + Args: + data_path (str): Path to the data. The path type defines how the data + should be read: + + * For a file path, the file is read and each line is parsed into a + video path and label. + * For a directory, the directory structure defines the classes + (i.e. each subdirectory is a class). + + clip_sampler (ClipSampler): Defines how clips should be sampled from each + video. See the clip sampling documentation for more information. + + video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal + video container. This defines the order videos are decoded and, + if necessary, the distributed split. + + transform (Callable): This callable is evaluated on the clip output before + the clip is returned. It can be used for user defined preprocessing and + augmentations to the clips. See the ``LabeledVideoDataset`` class for clip + output format. + + video_path_prefix (str): Path to root directory with the videos that are + loaded in ``LabeledVideoDataset``. All the video paths before loading + are prefixed with this path. + + decode_audio (bool): If True, also decode audio from video. + + decoder (str): Defines what type of decoder used to decode a video. 
+ + """ + labeled_video_paths = LabeledVideoPaths.from_path(data_path) + labeled_video_paths.path_prefix = video_path_prefix + dataset = LabeledVideoDataset( + labeled_video_paths, + clip_sampler, + video_sampler, + transform, + decode_audio=decode_audio, + decoder=decoder, + ) + return dataset diff --git a/video_transformers/pytorchvideo_wrapper/data/labeled_video_paths.py b/video_transformers/pytorchvideo_wrapper/data/labeled_video_paths.py index 8415c07..0d041ed 100644 --- a/video_transformers/pytorchvideo_wrapper/data/labeled_video_paths.py +++ b/video_transformers/pytorchvideo_wrapper/data/labeled_video_paths.py @@ -6,7 +6,6 @@ import logging import os import pathlib -import zipfile from collections import defaultdict from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast @@ -18,8 +17,6 @@ from pytorchvideo.data.video import VideoPathHandler from torchvision.datasets.folder import find_classes, has_file_allowed_extension, make_dataset -from video_transformers.utils.file import download_file - logger = logging.getLogger(__name__) @@ -292,73 +289,3 @@ def __len__(self): return self._len else: raise ValueError(f"Length calculation not implemented for sampler: {type(self.video_sampler)}.") - - -def labeled_video_dataset( - data_path: str, - clip_sampler: ClipSampler, - video_sampler: Type[torch.utils.data.Sampler] = None, - transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, - video_path_prefix: str = "", - decode_audio: bool = True, - decoder: str = "pyav", - dataset_multiplier: int = 1, -) -> LabeledVideoDataset: - """ - A helper function to create ``LabeledVideoDataset`` object for Ucf101 and Kinetics datasets. - - Args: - data_path (str): Path to the data. The path type defines how the data - should be read: - - * For a file path, the file is read and each line is parsed into a - video path and label. - * For a directory, the directory structure defines the classes - (i.e. each subdirectory is a class). - - clip_sampler (ClipSampler): Defines how clips should be sampled from each - video. See the clip sampling documentation for more information. - - video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal - video container. This defines the order videos are decoded and, - if necessary, the distributed split. - - transform (Callable): This callable is evaluated on the clip output before - the clip is returned. It can be used for user defined preprocessing and - augmentations to the clips. See the ``LabeledVideoDataset`` class for clip - output format. - - video_path_prefix (str): Path to root directory with the videos that are - loaded in ``LabeledVideoDataset``. All the video paths before loading - are prefixed with this path. - - decode_audio (bool): If True, also decode audio from video. - - decoder (str): Defines what type of decoder used to decode a video. - - """ - labeled_video_paths = LabeledVideoPaths.from_path(data_path) - labeled_video_paths.path_prefix = video_path_prefix - video_sampler = torch.utils.data.RandomSampler( - replacement=True, num_samples=len(labeled_video_paths) * dataset_multiplier - ) - dataset = LabeledVideoDataset( - labeled_video_paths, - clip_sampler, - video_sampler, - transform, - decode_audio=decode_audio, - decoder=decoder, - ) - return dataset - - -def download_ucf6(data_path: str): - """ - Downloads the ucf6 dataset to the given path. 
- """ - download_url = "https://github.com/fcakyon/video-transformers/releases/download/0.0.0/ucf6.zip" - download_path = os.path.join(data_path, "ucf6.zip") - download_file(download_url, download_path) - with zipfile.ZipFile(download_path, "r") as zip_ref: - zip_ref.extractall(data_path) diff --git a/video_transformers/utils/extra.py b/video_transformers/utils/extra.py index 19b898e..663c920 100644 --- a/video_transformers/utils/extra.py +++ b/video_transformers/utils/extra.py @@ -51,8 +51,8 @@ def scheduler_to_config(scheduler): if isinstance(main_scheduler, torch.optim.lr_scheduler.LinearLR): return { "optimizer": { - "name": scheduler.optimizer.__class__.__name__, - "defaults": scheduler.optimizer.defaults, + "name": main_scheduler.optimizer.__class__.__name__, + "defaults": main_scheduler.optimizer.defaults, }, "warmup_scheduler": { "class": "torch.optim.lr_scheduler.LinearLR", diff --git a/video_transformers/utils/file.py b/video_transformers/utils/file.py index ce0cdc7..42d2857 100644 --- a/video_transformers/utils/file.py +++ b/video_transformers/utils/file.py @@ -2,6 +2,7 @@ import os import re import urllib.request +import zipfile from pathlib import Path @@ -22,8 +23,20 @@ def download_file(url: str, download_path: str): """ Downloads a file from the given url to the given path. """ + Path(download_path).parent.mkdir(parents=True, exist_ok=True) if not os.path.exists(download_path): print(f"Downloading {url} to {download_path}") urllib.request.urlretrieve(url, download_path) else: print(f"{download_path} already exists. Skipping download.") + + +def download_ucf6(download_folder_path: str): + """ + Downloads the ucf6 dataset to the given folder. + """ + download_url = "https://github.com/fcakyon/video-transformers/releases/download/0.0.2/ucf6.zip" + download_path = Path(download_folder_path) / "ucf6.zip" + download_file(download_url, download_path) + with zipfile.ZipFile(download_path, "r") as zip_ref: + zip_ref.extractall(download_folder_path)