Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove opencv dependencies partial #1743

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,6 @@ exclude = '''(?x)(
tests/core/test_Core.py |
tests/data/test_multi_view_collate.py |
tests/data/test_data_collate.py |
tests/data/test_VideoDataset.py |
tests/data/test_LightlySubset.py |
tests/data/test_LightlyDataset.py |
tests/embedding/test_callbacks.py |
Expand Down
76 changes: 44 additions & 32 deletions tests/data/test_VideoDataset.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from __future__ import annotations

import contextlib
import io
import os
import shutil
import tempfile
import unittest
from fractions import Fraction
from typing import List
from typing import Any
from unittest import mock

import cv2
import numpy as np
import PIL
import torch
import torchvision
Expand Down Expand Up @@ -37,57 +37,63 @@
PYAV_AVAILABLE or VIDEO_READER_AVAILABLE, "No video backend available"
)
class TestVideoDataset(unittest.TestCase):
def tearDown(self):
def tearDown(self) -> None:
# Make sure to set the default backend to not interfere with other tests.
torchvision.set_video_backend(DEFAULT_BACKEND)

def ensure_dir(self, path_to_folder: str):
def ensure_dir(self, path_to_folder: str) -> None:
if not os.path.exists(path_to_folder):
os.makedirs(path_to_folder)

def create_dataset_specified_frames_per_video(
self, frames_per_video: List[int], w=32, h=32, c=3
):
self, frames_per_video: list[int], w: int = 32, h: int = 32, c: int = 3
) -> None:
self.input_dir = tempfile.mkdtemp()
self.ensure_dir(self.input_dir)
self.frames_over_videos = [
(np.random.randn(frames, w, h, c) * 255).astype(np.uint8)
torch.randint(low=0, high=256, size=(frames, h, w, c), dtype=torch.uint8)
for frames in frames_per_video
]

self.extensions = ".avi"

for frames in self.frames_over_videos:
path = os.path.join(self.input_dir, f"output-{len(frames):03}.avi")
print(path)
out = cv2.VideoWriter(path, 0, 1, (w, h))
for frame in frames:
out.write(frame)
out.release()
torchvision.io.write_video(
filename=path, video_array=frames, fps=1, video_codec="rawvideo"
)

def create_dataset(self, n_videos=5, n_frames_per_video=10, w=32, h=32, c=3):
def create_dataset(
self,
n_videos: int = 5,
n_frames_per_video: int = 10,
w: int = 32,
h: int = 32,
c: int = 3,
) -> None:
self.n_videos = n_videos
self.n_frames_per_video = n_frames_per_video

self.input_dir = tempfile.mkdtemp()
self.ensure_dir(self.input_dir)
self.frames = (np.random.randn(n_frames_per_video, w, h, c) * 255).astype(
np.uint8
self.frames = torch.randint(
low=0, high=256, size=(n_frames_per_video, h, w, c), dtype=torch.uint8
)
self.extensions = ".avi"

for i in range(n_videos):
path = os.path.join(self.input_dir, f"output-{i}.avi")
print(path)
out = cv2.VideoWriter(path, 0, 1, (w, h))
for frame in self.frames:
out.write(frame)
out.release()
torchvision.io.write_video(
filename=path, video_array=self.frames, fps=1, video_codec="rawvideo"
)

@unittest.skipUnless(
PYAV_AVAILABLE and VIDEO_READER_AVAILABLE,
"pyav and video_reader backends must be both available",
)
def test_video_similar_timestamps_for_different_backends(self):
def test_video_similar_timestamps_for_different_backends(self) -> None:
frames_per_video = list(range(1, 10))
self.create_dataset_specified_frames_per_video(frames_per_video)

Expand Down Expand Up @@ -124,7 +130,7 @@ def test_video_similar_timestamps_for_different_backends(self):

shutil.rmtree(self.input_dir)

def test_video_dataset_tqdm_args(self):
def test_video_dataset_tqdm_args(self) -> None:
self.create_dataset()
desc = "test_video_dataset_tqdm_args description asdf"
f = io.StringIO()
Expand All @@ -140,7 +146,7 @@ def test_video_dataset_tqdm_args(self):
printed = f.getvalue()
self.assertTrue(desc in printed)

def test_video_dataset_init_dataloader(self):
def test_video_dataset_init_dataloader(self) -> None:
self.create_dataset()
dataset_4_workers = LightlyDataset(
self.input_dir, num_workers_video_frame_counting=4
Expand Down Expand Up @@ -173,7 +179,7 @@ def test_video_dataset_from_folder__video_reader(self) -> None:
torchvision.set_video_backend("video_reader")
self._test_video_dataset_from_folder()

def _test_video_dataset_from_folder(self):
def _test_video_dataset_from_folder(self) -> None:
self.create_dataset()

# create dataset
Expand All @@ -199,7 +205,7 @@ def _test_video_dataset_from_folder(self):

shutil.rmtree(self.input_dir)

def test_video_dataset_no_read_rights(self):
def test_video_dataset_no_read_rights(self) -> None:
n_videos = 7
self.create_dataset(n_videos=n_videos)

Expand All @@ -225,21 +231,27 @@ def test_video_dataset_no_read_rights(self):
dataset = LightlyDataset(self.input_dir)

@unittest.skipUnless(PYAV_AVAILABLE, "PyAV unavailable")
def test_video_dataset_non_increasing_timestamps__pyav(self):
def test_video_dataset_non_increasing_timestamps__pyav(self) -> None:
torchvision.set_video_backend("pyav")
self._test_video_dataset_non_increasing_timestamps()

@unittest.skipUnless(VIDEO_READER_AVAILABLE, "video_reader unavailable")
def test_video_dataset_non_increasing_timestamps__video_reader(self):
def test_video_dataset_non_increasing_timestamps__video_reader(self) -> None:
torchvision.set_video_backend("video_reader")
self._test_video_dataset_non_increasing_timestamps()

def _test_video_dataset_non_increasing_timestamps(self):
def _test_video_dataset_non_increasing_timestamps(self) -> None:
self.create_dataset(n_videos=2, n_frames_per_video=5)

# overwrite the _make_dataset function to return a wrong timestamp
def _make_dataset_with_non_increasing_timestamps(*args, **kwargs):
def _make_dataset_with_non_increasing_timestamps(
*args: tuple[
str, tuple[str] | None, bool | None, str, dict[str, Any] | None, int
],
**kwargs: dict[str, Any],
) -> tuple[list[str], list[int], list[int], list[int]]:
video_instances, timestamps, offsets, fpss = _make_dataset(*args, **kwargs)
print(video_instances, timestamps, offsets, fpss)
# set timestamp of 4th frame in 1st video to timestamp of 2nd frame.
timestamps[0][3] = timestamps[0][1]
return video_instances, timestamps, offsets, fpss
Expand Down Expand Up @@ -281,16 +293,16 @@ def _make_dataset_with_non_increasing_timestamps(*args, **kwargs):
self.assertEqual(total_frames, len(dataset))

@unittest.skipUnless(PYAV_AVAILABLE, "PyAV unavailable")
def test_video_dataset_dataloader__pyav(self):
def test_video_dataset_dataloader__pyav(self) -> None:
torchvision.set_video_backend("pyav")
self._test_video_dataset_dataloader()

@unittest.skipUnless(VIDEO_READER_AVAILABLE, "video_reader unavailable")
def test_video_dataset_dataloader__video_reader(self):
def test_video_dataset_dataloader__video_reader(self) -> None:
torchvision.set_video_backend("video_reader")
self._test_video_dataset_dataloader()

def _test_video_dataset_dataloader(self):
def _test_video_dataset_dataloader(self) -> None:
self.create_dataset()
dataset = VideoDataset(self.input_dir, extensions=self.extensions)
dataloader = torch.utils.data.DataLoader(
Expand All @@ -303,7 +315,7 @@ def _test_video_dataset_dataloader(self):
for batch in dataloader:
pass

def test_find_non_increasing_timestamps(self):
def test_find_non_increasing_timestamps(self) -> None:
# no timestamps
non_increasing = _find_non_increasing_timestamps([])
self.assertListEqual(non_increasing, [])
Expand Down