Skip to content

Commit

Permalink
Merge branch 'cleanup-docstrings-data' of https://github.com/ChiragAg…
Browse files Browse the repository at this point in the history
…g5k/lightly into cleanup-docstrings-data
  • Loading branch information
philippmwirth committed Nov 22, 2024
2 parents 522a73f + 00d7593 commit 4a1933f
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ Lightly AI
lightly.data
lightly.loss
lightly.models
lightly.models.utils
lightly.transforms
lightly.utils

Expand Down
5 changes: 5 additions & 0 deletions docs/source/lightly.models.utils.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
lightly.models.utils
=====================

.. automodule:: lightly.models.utils
:members:
9 changes: 5 additions & 4 deletions lightly/data/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
from __future__ import annotations

import os
from typing import Any, Callable, Dict, Optional, Tuple
Expand Down Expand Up @@ -34,7 +35,7 @@
VIDEO_EXTENSIONS = (".mp4", ".mov", ".avi", ".mpg", ".hevc", ".m4v", ".webm", ".mpeg")


def _dir_contains_videos(root: str, extensions: Tuple[str, ...]) -> bool:
def _dir_contains_videos(root: str, extensions: tuple[str, ...]) -> bool:
"""Checks whether the directory contains video files.
Args:
Expand All @@ -48,7 +49,7 @@ def _dir_contains_videos(root: str, extensions: Tuple[str, ...]) -> bool:
return any(f.name.lower().endswith(extensions) for f in scan_dir)


def _contains_videos(root: str, extensions: Tuple[str, ...]) -> bool:
def _contains_videos(root: str, extensions: tuple[str, ...]) -> bool:
"""Checks whether the directory or any subdirectory contains video files.
Args:
Expand Down Expand Up @@ -92,8 +93,8 @@ def _contains_subdirs(root: str) -> bool:
def _load_dataset_from_folder(
root: str,
transform: Callable[[Any], Any],
is_valid_file: Optional[Callable[[str], bool]] = None,
tqdm_args: Optional[Dict[str, Any]] = None,
is_valid_file: Callable[[str], bool] | None,
tqdm_args: dict[str, Any] | None,
num_workers_video_frame_counting: int = 0,
) -> datasets.VisionDataset:
"""Initializes a dataset from a folder.
Expand Down
4 changes: 2 additions & 2 deletions lightly/data/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@

# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved
from __future__ import annotations

import os
from typing import List, Tuple

import tqdm.contrib.concurrent as concurrent
from PIL import Image, UnidentifiedImageError

from lightly.data import LightlyDataset


def check_images(data_dir: str) -> Tuple[List[str], List[str]]:
def check_images(data_dir: str) -> tuple[list[str], list[str]]:
"""Identifies corrupt and healthy images in the specified directory.
The function attempts to open each image file in the directory to verify
Expand Down
57 changes: 57 additions & 0 deletions lightly/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,63 @@
from timm.models.vision_transformer import VisionTransformer


def pool_masked(
source: Tensor, mask: Tensor, reduce: str = "mean", num_cls: Optional[int] = None
) -> Tensor:
"""Reduce image feature maps (B, C, H, W) or (C, H, W) according to an integer
index given by `mask` (B, H, W) or (H, W).
Args:
source: Float tensor of shape (B, C, H, W) or (C, H, W) to be reduced.
mask: Integer tensor of shape (B, H, W) or (H, W) containing the integer indices.
reduce: The reduction operation to be applied, one of 'prod', 'mean', 'amax' or
'amin'. Defaults to 'mean'.
num_cls: The number of classes in the possible masks. If None, the number of classes
is inferred from the unique elements in `mask`. This is useful when not all
classes are present in the mask.
Returns:
A tensor of shape (B, C, N) or (C, N) where N is the number of unique elements
in `mask` or `num_cls` if specified.
"""
if source.dim() == 3:
return _mask_reduce(source, mask, reduce, num_cls)
elif source.dim() == 4:
return _mask_reduce_batched(source, mask, num_cls)
else:
raise ValueError("source must have 3 or 4 dimensions")


def _mask_reduce(
source: Tensor, mask: Tensor, reduce: str = "mean", num_cls: Optional[int] = None
) -> Tensor:
output = _mask_reduce_batched(
source.unsqueeze(0), mask.unsqueeze(0), num_cls=num_cls
)
return output.squeeze(0)


def _mask_reduce_batched(
source: Tensor, mask: Tensor, num_cls: Optional[int] = None
) -> Tensor:
b, c, h, w = source.shape
if num_cls is None:
cls = mask.unique(sorted=True)
else:
cls = torch.arange(num_cls, device=mask.device)
num_cls = cls.size(0)
# create output tensor
output = source.new_zeros((b, c, num_cls)) # (B C N)
mask = mask.unsqueeze(1).expand(-1, c, -1, -1).view(b, c, -1) # (B C HW)
source = source.view(b, c, -1) # (B C HW)
output.scatter_reduce_(
dim=2, index=mask, src=source, reduce="mean", include_self=False
) # (B C N)
# scatter_reduce_ produces NaNs if the count is zero
output = torch.nan_to_num(output, nan=0.0)
return output


@torch.no_grad()
def batch_shuffle(
batch: torch.Tensor, distributed: bool = False
Expand Down
116 changes: 116 additions & 0 deletions tests/models/test_ModelUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,132 @@

from lightly.models import utils
from lightly.models.utils import (
_mask_reduce,
_mask_reduce_batched,
_no_grad_trunc_normal,
activate_requires_grad,
batch_shuffle,
batch_unshuffle,
deactivate_requires_grad,
nearest_neighbors,
normalize_weight,
pool_masked,
update_momentum,
)

is_scatter_reduce_available = hasattr(Tensor, "scatter_reduce_")


@pytest.mark.skipif(
not is_scatter_reduce_available,
reason="scatter operations require torch >= 1.12.0",
)
class TestMaskReduce:
@pytest.fixture()
def mask1(self) -> Tensor:
return torch.tensor([[0, 0], [1, 2]], dtype=torch.int64)

@pytest.fixture()
def mask2(self) -> Tensor:
return torch.tensor([[1, 0], [0, 1]], dtype=torch.int64)

@pytest.fixture()
def feature_map1(self) -> Tensor:
feature_map = torch.tensor(
[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]],
dtype=torch.float32,
) # (C H W) = (3, 2, 2)
return feature_map

@pytest.fixture()
def feature_map2(self) -> Tensor:
feature_map = torch.tensor(
[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]],
dtype=torch.float32,
) # (C H W) = (3, 2, 2)
return feature_map

@pytest.fixture()
def expected_result1(self) -> Tensor:
res = torch.tensor(
[[0.5, 2.0, 3.0], [4.5, 6.0, 7.0], [8.5, 10.0, 11.0]], dtype=torch.float32
)
return res

@pytest.fixture()
def expected_result2(self) -> Tensor:
res = torch.tensor(
[[2.5, 2.5, 0.0], [6.5, 6.5, 0.0], [10.5, 10.5, 0.0]], dtype=torch.float32
)
return res

def test__mask_reduce_batched(
self,
feature_map1: Tensor,
feature_map2: Tensor,
mask1: Tensor,
mask2: Tensor,
expected_result1: Tensor,
expected_result2: Tensor,
) -> None:
feature_map = torch.stack([feature_map1, feature_map2], dim=0)
mask = torch.stack([mask1, mask2], dim=0)
expected_result = torch.stack([expected_result1, expected_result2], dim=0)

out = _mask_reduce_batched(feature_map, mask, num_cls=3)
assert (out == expected_result).all()

def test_masked_pooling_manual(
self, feature_map2: Tensor, mask2: Tensor, expected_result2: Tensor
) -> None:
out_manual = pool_masked(
feature_map2.unsqueeze(0), mask2.unsqueeze(0), num_cls=2
)
assert out_manual.shape == (1, 3, 2)
assert (out_manual == expected_result2[:, :2]).all()

def test_masked_pooling_auto(
self, feature_map2: Tensor, mask2: Tensor, expected_result2: Tensor
) -> None:
out_auto = pool_masked(
feature_map2.unsqueeze(0), mask2.unsqueeze(0), num_cls=None
)
assert out_auto.shape == (1, 3, 2)
assert (out_auto == expected_result2[:, :2]).all()

@pytest.mark.parametrize(
"feature_map, mask, expected_result",
[
(
torch.tensor(
[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]],
dtype=torch.float32,
),
torch.tensor([[0, 0], [1, 2]], dtype=torch.int64),
torch.tensor(
[[0.5, 2.0, 3.0], [4.5, 6.0, 7.0], [8.5, 10.0, 11.0]],
dtype=torch.float32,
),
),
(
torch.tensor(
[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]],
dtype=torch.float32,
),
torch.tensor([[1, 0], [0, 1]], dtype=torch.int64),
torch.tensor(
[[2.5, 2.5, 0.0], [6.5, 6.5, 0.0], [10.5, 10.5, 0.0]],
dtype=torch.float32,
),
),
],
)
def test__mask_reduce(
self, feature_map: Tensor, mask: Tensor, expected_result: Tensor
) -> None:
out = _mask_reduce(feature_map, mask, num_cls=3)
assert (out == expected_result).all()


def has_grad(model: nn.Module):
"""Helper method to check if a model has `requires_grad` set to True"""
Expand Down

0 comments on commit 4a1933f

Please sign in to comment.