Skip to content

Commit

Permalink
refactor: remove opencv dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
vectorvp committed Nov 22, 2024
1 parent 2af5181 commit 53aaa49
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 24 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ exclude = '''(?x)(
tests/core/test_Core.py |
tests/data/test_multi_view_collate.py |
tests/data/test_data_collate.py |
tests/data/test_VideoDataset.py |
tests/data/test_LightlySubset.py |
tests/data/test_LightlyDataset.py |
tests/embedding/test_callbacks.py |
Expand Down
62 changes: 39 additions & 23 deletions tests/data/test_VideoDataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import annotations

import contextlib
import io
import os
import shutil
import tempfile
import unittest
from fractions import Fraction
from typing import List
from typing import Any
from unittest import mock

import numpy as np
import PIL
import torch
import torchvision
Expand Down Expand Up @@ -36,17 +37,17 @@
PYAV_AVAILABLE or VIDEO_READER_AVAILABLE, "No video backend available"
)
class TestVideoDataset(unittest.TestCase):
def tearDown(self):
def tearDown(self) -> None:
# Make sure to set the default backend to not interfere with other tests.
torchvision.set_video_backend(DEFAULT_BACKEND)

def ensure_dir(self, path_to_folder: str):
def ensure_dir(self, path_to_folder: str) -> None:
if not os.path.exists(path_to_folder):
os.makedirs(path_to_folder)

def create_dataset_specified_frames_per_video(
self, frames_per_video: List[int], w=32, h=32, c=3
):
self, frames_per_video: list[int], w: int = 32, h: int = 32, c: int = 3
) -> None:
self.input_dir = tempfile.mkdtemp()
self.ensure_dir(self.input_dir)
self.frames_over_videos = [
Expand All @@ -58,9 +59,16 @@ def create_dataset_specified_frames_per_video(
for frames in self.frames_over_videos:
path = os.path.join(self.input_dir, f"output-{len(frames):03}.avi")
print(path)
out = torchvision.io.write_video(path, self.frames_over_videos, frames)

def create_dataset(self, n_videos=5, n_frames_per_video=10, w=32, h=32, c=3):
out = torchvision.io.write_video(filename=path, video_array=frames, fps=1)

def create_dataset(
self,
n_videos: int = 5,
n_frames_per_video: int = 10,
w: int = 32,
h: int = 32,
c: int = 3,
) -> None:
self.n_videos = n_videos
self.n_frames_per_video = n_frames_per_video

Expand All @@ -72,13 +80,15 @@ def create_dataset(self, n_videos=5, n_frames_per_video=10, w=32, h=32, c=3):
for i in range(n_videos):
path = os.path.join(self.input_dir, f"output-{i}.avi")
print(path)
out = torchvision.io.write_video(path, self.frames, n_frames_per_video)
out = torchvision.io.write_video(
filename=path, video_array=self.frames, fps=1
)

@unittest.skipUnless(
PYAV_AVAILABLE and VIDEO_READER_AVAILABLE,
"pyav and video_reader backends must be both available",
)
def test_video_similar_timestamps_for_different_backends(self):
def test_video_similar_timestamps_for_different_backends(self) -> None:
frames_per_video = list(range(1, 10))
self.create_dataset_specified_frames_per_video(frames_per_video)

Expand Down Expand Up @@ -115,7 +125,7 @@ def test_video_similar_timestamps_for_different_backends(self):

shutil.rmtree(self.input_dir)

def test_video_dataset_tqdm_args(self):
def test_video_dataset_tqdm_args(self) -> None:
self.create_dataset()
desc = "test_video_dataset_tqdm_args description asdf"
f = io.StringIO()
Expand All @@ -131,7 +141,7 @@ def test_video_dataset_tqdm_args(self):
printed = f.getvalue()
self.assertTrue(desc in printed)

def test_video_dataset_init_dataloader(self):
def test_video_dataset_init_dataloader(self) -> None:
self.create_dataset()
dataset_4_workers = LightlyDataset(
self.input_dir, num_workers_video_frame_counting=4
Expand Down Expand Up @@ -164,7 +174,7 @@ def test_video_dataset_from_folder__video_reader(self) -> None:
torchvision.set_video_backend("video_reader")
self._test_video_dataset_from_folder()

def _test_video_dataset_from_folder(self):
def _test_video_dataset_from_folder(self) -> None:
self.create_dataset()

# create dataset
Expand All @@ -190,7 +200,7 @@ def _test_video_dataset_from_folder(self):

shutil.rmtree(self.input_dir)

def test_video_dataset_no_read_rights(self):
def test_video_dataset_no_read_rights(self) -> None:
n_videos = 7
self.create_dataset(n_videos=n_videos)

Expand All @@ -216,21 +226,27 @@ def test_video_dataset_no_read_rights(self):
dataset = LightlyDataset(self.input_dir)

@unittest.skipUnless(PYAV_AVAILABLE, "PyAV unavailable")
def test_video_dataset_non_increasing_timestamps__pyav(self):
def test_video_dataset_non_increasing_timestamps__pyav(self) -> None:
torchvision.set_video_backend("pyav")
self._test_video_dataset_non_increasing_timestamps()

@unittest.skipUnless(VIDEO_READER_AVAILABLE, "video_reader unavailable")
def test_video_dataset_non_increasing_timestamps__video_reader(self):
def test_video_dataset_non_increasing_timestamps__video_reader(self) -> None:
torchvision.set_video_backend("video_reader")
self._test_video_dataset_non_increasing_timestamps()

def _test_video_dataset_non_increasing_timestamps(self):
def _test_video_dataset_non_increasing_timestamps(self) -> None:
self.create_dataset(n_videos=2, n_frames_per_video=5)

# overwrite the _make_dataset function to return a wrong timestamp
def _make_dataset_with_non_increasing_timestamps(*args, **kwargs):
def _make_dataset_with_non_increasing_timestamps(
*args: tuple[
str, tuple[str] | None, bool | None, str, dict[str, Any] | None, int
],
**kwargs: dict[str, Any],
) -> tuple[list[str], list[int], list[int], list[int]]:
video_instances, timestamps, offsets, fpss = _make_dataset(*args, **kwargs)
print(video_instances, timestamps, offsets, fpss)
# set timestamp of 4th frame in 1st video to timestamp of 2nd frame.
timestamps[0][3] = timestamps[0][1]
return video_instances, timestamps, offsets, fpss
Expand Down Expand Up @@ -272,16 +288,16 @@ def _make_dataset_with_non_increasing_timestamps(*args, **kwargs):
self.assertEqual(total_frames, len(dataset))

@unittest.skipUnless(PYAV_AVAILABLE, "PyAV unavailable")
def test_video_dataset_dataloader__pyav(self):
def test_video_dataset_dataloader__pyav(self) -> None:
torchvision.set_video_backend("pyav")
self._test_video_dataset_dataloader()

@unittest.skipUnless(VIDEO_READER_AVAILABLE, "video_reader unavailable")
def test_video_dataset_dataloader__video_reader(self):
def test_video_dataset_dataloader__video_reader(self) -> None:
torchvision.set_video_backend("video_reader")
self._test_video_dataset_dataloader()

def _test_video_dataset_dataloader(self):
def _test_video_dataset_dataloader(self) -> None:
self.create_dataset()
dataset = VideoDataset(self.input_dir, extensions=self.extensions)
dataloader = torch.utils.data.DataLoader(
Expand All @@ -294,7 +310,7 @@ def _test_video_dataset_dataloader(self):
for batch in dataloader:
pass

def test_find_non_increasing_timestamps(self):
def test_find_non_increasing_timestamps(self) -> None:
# no timestamps
non_increasing = _find_non_increasing_timestamps([])
self.assertListEqual(non_increasing, [])
Expand Down

0 comments on commit 53aaa49

Please sign in to comment.