diff --git a/pyproject.toml b/pyproject.toml index 405d16f07..13745ed33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -273,7 +273,6 @@ exclude = '''(?x)( tests/core/test_Core.py | tests/data/test_multi_view_collate.py | tests/data/test_data_collate.py | - tests/data/test_VideoDataset.py | tests/data/test_LightlySubset.py | tests/data/test_LightlyDataset.py | tests/embedding/test_callbacks.py | diff --git a/tests/data/test_VideoDataset.py b/tests/data/test_VideoDataset.py index 2cda10011..d5590851b 100644 --- a/tests/data/test_VideoDataset.py +++ b/tests/data/test_VideoDataset.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import io import os @@ -5,10 +7,9 @@ import tempfile import unittest from fractions import Fraction -from typing import List +from typing import Any from unittest import mock -import numpy as np import PIL import torch import torchvision @@ -36,17 +37,17 @@ PYAV_AVAILABLE or VIDEO_READER_AVAILABLE, "No video backend available" ) class TestVideoDataset(unittest.TestCase): - def tearDown(self): + def tearDown(self) -> None: # Make sure to set the default backend to not interfere with other tests. torchvision.set_video_backend(DEFAULT_BACKEND) - def ensure_dir(self, path_to_folder: str): + def ensure_dir(self, path_to_folder: str) -> None: if not os.path.exists(path_to_folder): os.makedirs(path_to_folder) def create_dataset_specified_frames_per_video( - self, frames_per_video: List[int], w=32, h=32, c=3 - ): + self, frames_per_video: list[int], w: int = 32, h: int = 32, c: int = 3 + ) -> None: self.input_dir = tempfile.mkdtemp() self.ensure_dir(self.input_dir) self.frames_over_videos = [ @@ -58,9 +59,16 @@ def create_dataset_specified_frames_per_video( for frames in self.frames_over_videos: path = os.path.join(self.input_dir, f"output-{len(frames):03}.avi") print(path) - out = torchvision.io.write_video(path, self.frames_over_videos, frames) - - def create_dataset(self, n_videos=5, n_frames_per_video=10, w=32, h=32, c=3): + out = torchvision.io.write_video(filename=path, video_array=frames, fps=1) + + def create_dataset( + self, + n_videos: int = 5, + n_frames_per_video: int = 10, + w: int = 32, + h: int = 32, + c: int = 3, + ) -> None: self.n_videos = n_videos self.n_frames_per_video = n_frames_per_video @@ -72,13 +80,15 @@ def create_dataset(self, n_videos=5, n_frames_per_video=10, w=32, h=32, c=3): for i in range(n_videos): path = os.path.join(self.input_dir, f"output-{i}.avi") print(path) - out = torchvision.io.write_video(path, self.frames, n_frames_per_video) + out = torchvision.io.write_video( + filename=path, video_array=self.frames, fps=1 + ) @unittest.skipUnless( PYAV_AVAILABLE and VIDEO_READER_AVAILABLE, "pyav and video_reader backends must be both available", ) - def test_video_similar_timestamps_for_different_backends(self): + def test_video_similar_timestamps_for_different_backends(self) -> None: frames_per_video = list(range(1, 10)) self.create_dataset_specified_frames_per_video(frames_per_video) @@ -115,7 +125,7 @@ def test_video_similar_timestamps_for_different_backends(self): shutil.rmtree(self.input_dir) - def test_video_dataset_tqdm_args(self): + def test_video_dataset_tqdm_args(self) -> None: self.create_dataset() desc = "test_video_dataset_tqdm_args description asdf" f = io.StringIO() @@ -131,7 +141,7 @@ def test_video_dataset_tqdm_args(self): printed = f.getvalue() self.assertTrue(desc in printed) - def test_video_dataset_init_dataloader(self): + def test_video_dataset_init_dataloader(self) -> None: self.create_dataset() dataset_4_workers = LightlyDataset( self.input_dir, num_workers_video_frame_counting=4 @@ -164,7 +174,7 @@ def test_video_dataset_from_folder__video_reader(self) -> None: torchvision.set_video_backend("video_reader") self._test_video_dataset_from_folder() - def _test_video_dataset_from_folder(self): + def _test_video_dataset_from_folder(self) -> None: self.create_dataset() # create dataset @@ -190,7 +200,7 @@ def _test_video_dataset_from_folder(self): shutil.rmtree(self.input_dir) - def test_video_dataset_no_read_rights(self): + def test_video_dataset_no_read_rights(self) -> None: n_videos = 7 self.create_dataset(n_videos=n_videos) @@ -216,21 +226,27 @@ def test_video_dataset_no_read_rights(self): dataset = LightlyDataset(self.input_dir) @unittest.skipUnless(PYAV_AVAILABLE, "PyAV unavailable") - def test_video_dataset_non_increasing_timestamps__pyav(self): + def test_video_dataset_non_increasing_timestamps__pyav(self) -> None: torchvision.set_video_backend("pyav") self._test_video_dataset_non_increasing_timestamps() @unittest.skipUnless(VIDEO_READER_AVAILABLE, "video_reader unavailable") - def test_video_dataset_non_increasing_timestamps__video_reader(self): + def test_video_dataset_non_increasing_timestamps__video_reader(self) -> None: torchvision.set_video_backend("video_reader") self._test_video_dataset_non_increasing_timestamps() - def _test_video_dataset_non_increasing_timestamps(self): + def _test_video_dataset_non_increasing_timestamps(self) -> None: self.create_dataset(n_videos=2, n_frames_per_video=5) # overwrite the _make_dataset function to return a wrong timestamp - def _make_dataset_with_non_increasing_timestamps(*args, **kwargs): + def _make_dataset_with_non_increasing_timestamps( + *args: tuple[ + str, tuple[str] | None, bool | None, str, dict[str, Any] | None, int + ], + **kwargs: dict[str, Any], + ) -> tuple[list[str], list[int], list[int], list[int]]: video_instances, timestamps, offsets, fpss = _make_dataset(*args, **kwargs) + print(video_instances, timestamps, offsets, fpss) # set timestamp of 4th frame in 1st video to timestamp of 2nd frame. timestamps[0][3] = timestamps[0][1] return video_instances, timestamps, offsets, fpss @@ -272,16 +288,16 @@ def _make_dataset_with_non_increasing_timestamps(*args, **kwargs): self.assertEqual(total_frames, len(dataset)) @unittest.skipUnless(PYAV_AVAILABLE, "PyAV unavailable") - def test_video_dataset_dataloader__pyav(self): + def test_video_dataset_dataloader__pyav(self) -> None: torchvision.set_video_backend("pyav") self._test_video_dataset_dataloader() @unittest.skipUnless(VIDEO_READER_AVAILABLE, "video_reader unavailable") - def test_video_dataset_dataloader__video_reader(self): + def test_video_dataset_dataloader__video_reader(self) -> None: torchvision.set_video_backend("video_reader") self._test_video_dataset_dataloader() - def _test_video_dataset_dataloader(self): + def _test_video_dataset_dataloader(self) -> None: self.create_dataset() dataset = VideoDataset(self.input_dir, extensions=self.extensions) dataloader = torch.utils.data.DataLoader( @@ -294,7 +310,7 @@ def _test_video_dataset_dataloader(self): for batch in dataloader: pass - def test_find_non_increasing_timestamps(self): + def test_find_non_increasing_timestamps(self) -> None: # no timestamps non_increasing = _find_non_increasing_timestamps([]) self.assertListEqual(non_increasing, [])