From 8366a47a169b7ece4e164d97bb09a284be38642d Mon Sep 17 00:00:00 2001 From: Snehil Chatterjee Date: Sat, 12 Oct 2024 01:15:27 +0530 Subject: [PATCH 1/5] Implementation of GaussianMixtureMasksTransform --- lightly/transforms/__init__.py | 1 + .../gaussian_mixture_masks_transform.py | 103 ++++++++++++++++++ .../transforms/test_gaussian_mixture_masks.py | 10 ++ 3 files changed, 114 insertions(+) create mode 100644 lightly/transforms/gaussian_mixture_masks_transform.py create mode 100644 tests/transforms/test_gaussian_mixture_masks.py diff --git a/lightly/transforms/__init__.py b/lightly/transforms/__init__.py index 949fbe905..2cdb45963 100644 --- a/lightly/transforms/__init__.py +++ b/lightly/transforms/__init__.py @@ -18,6 +18,7 @@ from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform from lightly.transforms.fast_siam_transform import FastSiamTransform from lightly.transforms.gaussian_blur import GaussianBlur +from lightly.transforms.gaussian_mixture_masks_transform import GaussianMixtureMasks from lightly.transforms.irfft2d_transform import IRFFT2DTransform from lightly.transforms.jigsaw import Jigsaw from lightly.transforms.mae_transform import MAETransform diff --git a/lightly/transforms/gaussian_mixture_masks_transform.py b/lightly/transforms/gaussian_mixture_masks_transform.py new file mode 100644 index 000000000..4ab98aeaf --- /dev/null +++ b/lightly/transforms/gaussian_mixture_masks_transform.py @@ -0,0 +1,103 @@ +from typing import Tuple + +import torch +import torch.fft +from torch import Tensor + +from lightly.transforms.irfft2d_transform import IRFFT2DTransform +from lightly.transforms.rfft2d_transform import RFFT2DTransform + + +class GaussianMixtureMasks: + """Applies a Gaussian Mixture Mask in the Fourier domain to RGB images. + + The mask is created using random Gaussian kernels, which are applied in + the frequency domain via RFFT2D, and then the IRFFT2D is used to return + to the spatial domain. The transformation is applied to each RGB channel separately. + + Attributes: + num_gaussians: Number of Gaussian kernels to generate in the mixture mask. + std_range: Tuple containing the minimum and maximum standard deviation for the Gaussians. + """ + + def __init__(self, num_gaussians: int = 20, std_range: Tuple[int, int] = (10, 15)): + """Initializes GaussianMixtureMasks with the given parameters. + + Args: + num_gaussians: Number of Gaussian kernels to generate in the mixture mask. + std_range: Tuple containing the minimum and maximum standard deviation for the Gaussians. + """ + + self.rfft2d_transform = RFFT2DTransform() + self.num_gaussians = num_gaussians + self.std_range = std_range + + def gaussian_kernel( + self, size: Tuple[int, int], sigma: Tensor, center: Tensor + ) -> Tensor: + """Generates a 2D Gaussian kernel. + + Args: + size: Tuple specifying the dimensions of the Gaussian kernel (C, H, W). + sigma: Tensor specifying the standard deviation of the Gaussian. + center: Tensor specifying the center of the Gaussian kernel. + + Returns: + Tensor: A 2D Gaussian kernel. + """ + u, v = torch.meshgrid(torch.arange(0, size[0]), torch.arange(0, size[1])) + u0, v0 = center + gaussian = torch.exp( + -((u - u0) ** 2 / (2 * sigma[0] ** 2) + (v - v0) ** 2 / (2 * sigma[1] ** 2)) + ) + + return gaussian + + def apply_gaussian_mixture_mask( + self, image_channel: Tensor, num_gaussians: int, std: Tuple[int, int] + ) -> Tensor: + """Applies the Gaussian mixture mask to a single channel in the frequency domain. + + Args: + image_channel: Tensor representing a single channel of the image. + num_gaussians: Number of Gaussian kernels to generate in the mask. + std: Tuple specifying the standard deviation range for the Gaussians. + + Returns: + Tensor: Image after applying the Gaussian mixture mask. + """ + image_size = image_channel[0].shape + + self.irfft2d_transform = IRFFT2DTransform((image_size[0], image_size[1])) + f_transform = self.rfft2d_transform(image_channel) + + size = f_transform[0].shape + + mask = torch.ones(size) + + for _ in range(num_gaussians): + u0 = torch.randint(0, size[0], (1,)) + v0 = torch.randint(0, size[1], (1,)) + center = torch.tensor((u0, v0)) + sigma = torch.rand(2) * 5 + 10 + + g_kernel = self.gaussian_kernel((size[0], size[1]), sigma, center) + mask -= g_kernel + + filtered_f_transform = f_transform * mask + filtered_image = self.irfft2d_transform(filtered_f_transform).abs() + return filtered_image + + def __call__(self, image_tensor: Tensor) -> Tensor: + """Applies the Gaussian mixture mask transformation to the input image. + + Args: + image_tensor: Tensor representing an RGB image of shape (C, H, W). + + Returns: + Tensor: The transformed image after applying the Gaussian mixture mask. + """ + transformed_channel: Tensor = self.apply_gaussian_mixture_mask( + image_tensor, self.num_gaussians, self.std_range + ) + return transformed_channel diff --git a/tests/transforms/test_gaussian_mixture_masks.py b/tests/transforms/test_gaussian_mixture_masks.py new file mode 100644 index 000000000..ae2fd9fe8 --- /dev/null +++ b/tests/transforms/test_gaussian_mixture_masks.py @@ -0,0 +1,10 @@ +import torch + +from lightly.transforms import GaussianMixtureMasks + + +def test() -> None: + transform = GaussianMixtureMasks(20, (10, 15)) + image = torch.rand(3, 32, 32) + output = transform(image) + assert output.shape == (3, 32, 32) From c9fe57a5db911239ec6fbe95859bcfe66503e234 Mon Sep 17 00:00:00 2001 From: Snehil Chatterjee Date: Mon, 14 Oct 2024 19:17:45 +0530 Subject: [PATCH 2/5] Implementing requested changes on GMM --- lightly/transforms/__init__.py | 2 +- .../gaussian_mixture_masks_transform.py | 54 +++++++++++-------- 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/lightly/transforms/__init__.py b/lightly/transforms/__init__.py index 2cdb45963..338def43a 100644 --- a/lightly/transforms/__init__.py +++ b/lightly/transforms/__init__.py @@ -18,7 +18,7 @@ from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform from lightly.transforms.fast_siam_transform import FastSiamTransform from lightly.transforms.gaussian_blur import GaussianBlur -from lightly.transforms.gaussian_mixture_masks_transform import GaussianMixtureMasks +from lightly.transforms.gaussian_mixture_masks_transform import GaussianMixtureMask from lightly.transforms.irfft2d_transform import IRFFT2DTransform from lightly.transforms.jigsaw import Jigsaw from lightly.transforms.mae_transform import MAETransform diff --git a/lightly/transforms/gaussian_mixture_masks_transform.py b/lightly/transforms/gaussian_mixture_masks_transform.py index 4ab98aeaf..e911ed19f 100644 --- a/lightly/transforms/gaussian_mixture_masks_transform.py +++ b/lightly/transforms/gaussian_mixture_masks_transform.py @@ -8,19 +8,21 @@ from lightly.transforms.rfft2d_transform import RFFT2DTransform -class GaussianMixtureMasks: - """Applies a Gaussian Mixture Mask in the Fourier domain to RGB images. +class GaussianMixtureMask: + """Applies a Gaussian Mixture Mask in the Fourier domain to a single-channel image. The mask is created using random Gaussian kernels, which are applied in the frequency domain via RFFT2D, and then the IRFFT2D is used to return - to the spatial domain. The transformation is applied to each RGB channel separately. + to the spatial domain. The transformation is applied to each image channel separately. Attributes: num_gaussians: Number of Gaussian kernels to generate in the mixture mask. std_range: Tuple containing the minimum and maximum standard deviation for the Gaussians. """ - def __init__(self, num_gaussians: int = 20, std_range: Tuple[int, int] = (10, 15)): + def __init__( + self, num_gaussians: int = 20, std_range: Tuple[float, float] = (10, 15) + ): """Initializes GaussianMixtureMasks with the given parameters. Args: @@ -29,6 +31,7 @@ def __init__(self, num_gaussians: int = 20, std_range: Tuple[int, int] = (10, 15 """ self.rfft2d_transform = RFFT2DTransform() + self.num_gaussians = num_gaussians self.std_range = std_range @@ -38,7 +41,7 @@ def gaussian_kernel( """Generates a 2D Gaussian kernel. Args: - size: Tuple specifying the dimensions of the Gaussian kernel (C, H, W). + size: Tuple specifying the dimensions of the Gaussian kernel (H, W). sigma: Tensor specifying the standard deviation of the Gaussian. center: Tensor specifying the center of the Gaussian kernel. @@ -46,6 +49,8 @@ def gaussian_kernel( Tensor: A 2D Gaussian kernel. """ u, v = torch.meshgrid(torch.arange(0, size[0]), torch.arange(0, size[1])) + u = u.to(sigma.device) + v = v.to(sigma.device) u0, v0 = center gaussian = torch.exp( -((u - u0) ** 2 / (2 * sigma[0] ** 2) + (v - v0) ** 2 / (2 * sigma[1] ** 2)) @@ -54,50 +59,53 @@ def gaussian_kernel( return gaussian def apply_gaussian_mixture_mask( - self, image_channel: Tensor, num_gaussians: int, std: Tuple[int, int] + self, freq_image: Tensor, num_gaussians: int, std: Tuple[int, int] ) -> Tensor: - """Applies the Gaussian mixture mask to a single channel in the frequency domain. + """Applies the Gaussian mixture mask to a frequency-domain image. Args: - image_channel: Tensor representing a single channel of the image. + freq_image: Tensor representing the frequency-domain image of shape (C, H, W//2+1). num_gaussians: Number of Gaussian kernels to generate in the mask. std: Tuple specifying the standard deviation range for the Gaussians. Returns: Tensor: Image after applying the Gaussian mixture mask. """ - image_size = image_channel[0].shape + image_size = freq_image.shape[1:] + original_height = image_size[0] + original_width = 2 * (image_size[1] - 1) + + original_shape = (original_height, original_width) - self.irfft2d_transform = IRFFT2DTransform((image_size[0], image_size[1])) - f_transform = self.rfft2d_transform(image_channel) + self.irfft2d_transform = IRFFT2DTransform(original_shape) - size = f_transform[0].shape + size = freq_image[0].shape - mask = torch.ones(size) + mask = freq_image.new_ones(freq_image.shape) for _ in range(num_gaussians): - u0 = torch.randint(0, size[0], (1,)) - v0 = torch.randint(0, size[1], (1,)) - center = torch.tensor((u0, v0)) - sigma = torch.rand(2) * 5 + 10 + u0 = torch.randint(0, size[0], (1,), device=freq_image.device) + v0 = torch.randint(0, size[1], (1,), device=freq_image.device) + center = torch.tensor((u0, v0), device=freq_image.device) + sigma = torch.rand(2, device=freq_image.device) * (std[1] - std[0]) + std[0] g_kernel = self.gaussian_kernel((size[0], size[1]), sigma, center) mask -= g_kernel - filtered_f_transform = f_transform * mask - filtered_image = self.irfft2d_transform(filtered_f_transform).abs() + filtered_freq_image = freq_image * mask + filtered_image = self.irfft2d_transform(filtered_freq_image).abs() return filtered_image - def __call__(self, image_tensor: Tensor) -> Tensor: - """Applies the Gaussian mixture mask transformation to the input image. + def __call__(self, freq_image: Tensor) -> Tensor: + """Applies the Gaussian mixture mask transformation to the input frequency-domain image. Args: - image_tensor: Tensor representing an RGB image of shape (C, H, W). + freq_image: Tensor representing a frequency-domain image of shape (C, H, W//2+1). Returns: Tensor: The transformed image after applying the Gaussian mixture mask. """ transformed_channel: Tensor = self.apply_gaussian_mixture_mask( - image_tensor, self.num_gaussians, self.std_range + freq_image, self.num_gaussians, self.std_range ) return transformed_channel From fa2fb1715d94de8b4727c6c45af4b3f29460bce5 Mon Sep 17 00:00:00 2001 From: Snehil Chatterjee Date: Mon, 14 Oct 2024 19:25:50 +0530 Subject: [PATCH 3/5] minor bug fixes for GMM --- lightly/transforms/gaussian_mixture_masks_transform.py | 2 +- tests/transforms/test_gaussian_mixture_masks.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lightly/transforms/gaussian_mixture_masks_transform.py b/lightly/transforms/gaussian_mixture_masks_transform.py index e911ed19f..c9804db3b 100644 --- a/lightly/transforms/gaussian_mixture_masks_transform.py +++ b/lightly/transforms/gaussian_mixture_masks_transform.py @@ -59,7 +59,7 @@ def gaussian_kernel( return gaussian def apply_gaussian_mixture_mask( - self, freq_image: Tensor, num_gaussians: int, std: Tuple[int, int] + self, freq_image: Tensor, num_gaussians: int, std: Tuple[float, float] ) -> Tensor: """Applies the Gaussian mixture mask to a frequency-domain image. diff --git a/tests/transforms/test_gaussian_mixture_masks.py b/tests/transforms/test_gaussian_mixture_masks.py index ae2fd9fe8..8f2c8678b 100644 --- a/tests/transforms/test_gaussian_mixture_masks.py +++ b/tests/transforms/test_gaussian_mixture_masks.py @@ -1,10 +1,10 @@ import torch -from lightly.transforms import GaussianMixtureMasks +from lightly.transforms import GaussianMixtureMask def test() -> None: - transform = GaussianMixtureMasks(20, (10, 15)) + transform = GaussianMixtureMask(20, (10, 15)) image = torch.rand(3, 32, 32) output = transform(image) assert output.shape == (3, 32, 32) From 2c9e7c8c79798affef146fd408176b277c186e8f Mon Sep 17 00:00:00 2001 From: Snehil Chatterjee Date: Mon, 14 Oct 2024 19:37:09 +0530 Subject: [PATCH 4/5] removed rfft2d init from GMM --- lightly/transforms/gaussian_mixture_masks_transform.py | 4 ---- tests/transforms/test_gaussian_mixture_masks.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/lightly/transforms/gaussian_mixture_masks_transform.py b/lightly/transforms/gaussian_mixture_masks_transform.py index c9804db3b..b9cc2bb6e 100644 --- a/lightly/transforms/gaussian_mixture_masks_transform.py +++ b/lightly/transforms/gaussian_mixture_masks_transform.py @@ -5,7 +5,6 @@ from torch import Tensor from lightly.transforms.irfft2d_transform import IRFFT2DTransform -from lightly.transforms.rfft2d_transform import RFFT2DTransform class GaussianMixtureMask: @@ -29,9 +28,6 @@ def __init__( num_gaussians: Number of Gaussian kernels to generate in the mixture mask. std_range: Tuple containing the minimum and maximum standard deviation for the Gaussians. """ - - self.rfft2d_transform = RFFT2DTransform() - self.num_gaussians = num_gaussians self.std_range = std_range diff --git a/tests/transforms/test_gaussian_mixture_masks.py b/tests/transforms/test_gaussian_mixture_masks.py index 8f2c8678b..40b51d457 100644 --- a/tests/transforms/test_gaussian_mixture_masks.py +++ b/tests/transforms/test_gaussian_mixture_masks.py @@ -5,6 +5,6 @@ def test() -> None: transform = GaussianMixtureMask(20, (10, 15)) - image = torch.rand(3, 32, 32) + image = torch.rand(3, 32, 17) output = transform(image) assert output.shape == (3, 32, 32) From b68b473bcdcbbdaec907ea3d7dcf0a23a3496f2a Mon Sep 17 00:00:00 2001 From: Snehil Chatterjee Date: Tue, 15 Oct 2024 19:34:29 +0530 Subject: [PATCH 5/5] GMM changed output to frequency domain output --- .../gaussian_mixture_masks_transform.py | 38 ++++++------------- .../transforms/test_gaussian_mixture_masks.py | 2 +- 2 files changed, 13 insertions(+), 27 deletions(-) diff --git a/lightly/transforms/gaussian_mixture_masks_transform.py b/lightly/transforms/gaussian_mixture_masks_transform.py index b9cc2bb6e..6d89d41c6 100644 --- a/lightly/transforms/gaussian_mixture_masks_transform.py +++ b/lightly/transforms/gaussian_mixture_masks_transform.py @@ -4,15 +4,12 @@ import torch.fft from torch import Tensor -from lightly.transforms.irfft2d_transform import IRFFT2DTransform - class GaussianMixtureMask: - """Applies a Gaussian Mixture Mask in the Fourier domain to a single-channel image. + """Applies a Gaussian Mixture Mask in the Fourier domain to an image. The mask is created using random Gaussian kernels, which are applied in - the frequency domain via RFFT2D, and then the IRFFT2D is used to return - to the spatial domain. The transformation is applied to each image channel separately. + the frequency domain. Attributes: num_gaussians: Number of Gaussian kernels to generate in the mixture mask. @@ -42,7 +39,7 @@ def gaussian_kernel( center: Tensor specifying the center of the Gaussian kernel. Returns: - Tensor: A 2D Gaussian kernel. + A 2D Gaussian kernel tensor. """ u, v = torch.meshgrid(torch.arange(0, size[0]), torch.arange(0, size[1])) u = u.to(sigma.device) @@ -65,32 +62,22 @@ def apply_gaussian_mixture_mask( std: Tuple specifying the standard deviation range for the Gaussians. Returns: - Tensor: Image after applying the Gaussian mixture mask. + Image tensor in frequency domain after applying the Gaussian mixture mask. """ - image_size = freq_image.shape[1:] - original_height = image_size[0] - original_width = 2 * (image_size[1] - 1) - - original_shape = (original_height, original_width) - - self.irfft2d_transform = IRFFT2DTransform(original_shape) - - size = freq_image[0].shape - + (C, U, V) = freq_image.shape mask = freq_image.new_ones(freq_image.shape) for _ in range(num_gaussians): - u0 = torch.randint(0, size[0], (1,), device=freq_image.device) - v0 = torch.randint(0, size[1], (1,), device=freq_image.device) + u0 = torch.randint(0, U, (1,), device=freq_image.device) + v0 = torch.randint(0, V, (1,), device=freq_image.device) center = torch.tensor((u0, v0), device=freq_image.device) sigma = torch.rand(2, device=freq_image.device) * (std[1] - std[0]) + std[0] - g_kernel = self.gaussian_kernel((size[0], size[1]), sigma, center) - mask -= g_kernel + g_kernel = self.gaussian_kernel((U, V), sigma, center) + mask *= 1 - g_kernel.unsqueeze(0) filtered_freq_image = freq_image * mask - filtered_image = self.irfft2d_transform(filtered_freq_image).abs() - return filtered_image + return filtered_freq_image def __call__(self, freq_image: Tensor) -> Tensor: """Applies the Gaussian mixture mask transformation to the input frequency-domain image. @@ -99,9 +86,8 @@ def __call__(self, freq_image: Tensor) -> Tensor: freq_image: Tensor representing a frequency-domain image of shape (C, H, W//2+1). Returns: - Tensor: The transformed image after applying the Gaussian mixture mask. + Image tensor in frequency domain after applying the Gaussian mixture mask. """ - transformed_channel: Tensor = self.apply_gaussian_mixture_mask( + return self.apply_gaussian_mixture_mask( freq_image, self.num_gaussians, self.std_range ) - return transformed_channel diff --git a/tests/transforms/test_gaussian_mixture_masks.py b/tests/transforms/test_gaussian_mixture_masks.py index 40b51d457..db687000b 100644 --- a/tests/transforms/test_gaussian_mixture_masks.py +++ b/tests/transforms/test_gaussian_mixture_masks.py @@ -7,4 +7,4 @@ def test() -> None: transform = GaussianMixtureMask(20, (10, 15)) image = torch.rand(3, 32, 17) output = transform(image) - assert output.shape == (3, 32, 32) + assert output.shape == image.shape