diff --git a/lightly/transforms/gaussian_mixture_masks_transform.py b/lightly/transforms/gaussian_mixture_masks_transform.py index b9cc2bb6e..6d89d41c6 100644 --- a/lightly/transforms/gaussian_mixture_masks_transform.py +++ b/lightly/transforms/gaussian_mixture_masks_transform.py @@ -4,15 +4,12 @@ import torch.fft from torch import Tensor -from lightly.transforms.irfft2d_transform import IRFFT2DTransform - class GaussianMixtureMask: - """Applies a Gaussian Mixture Mask in the Fourier domain to a single-channel image. + """Applies a Gaussian Mixture Mask in the Fourier domain to an image. The mask is created using random Gaussian kernels, which are applied in - the frequency domain via RFFT2D, and then the IRFFT2D is used to return - to the spatial domain. The transformation is applied to each image channel separately. + the frequency domain. Attributes: num_gaussians: Number of Gaussian kernels to generate in the mixture mask. @@ -42,7 +39,7 @@ def gaussian_kernel( center: Tensor specifying the center of the Gaussian kernel. Returns: - Tensor: A 2D Gaussian kernel. + A 2D Gaussian kernel tensor. """ u, v = torch.meshgrid(torch.arange(0, size[0]), torch.arange(0, size[1])) u = u.to(sigma.device) @@ -65,32 +62,22 @@ def apply_gaussian_mixture_mask( std: Tuple specifying the standard deviation range for the Gaussians. Returns: - Tensor: Image after applying the Gaussian mixture mask. + Image tensor in frequency domain after applying the Gaussian mixture mask. """ - image_size = freq_image.shape[1:] - original_height = image_size[0] - original_width = 2 * (image_size[1] - 1) - - original_shape = (original_height, original_width) - - self.irfft2d_transform = IRFFT2DTransform(original_shape) - - size = freq_image[0].shape - + (C, U, V) = freq_image.shape mask = freq_image.new_ones(freq_image.shape) for _ in range(num_gaussians): - u0 = torch.randint(0, size[0], (1,), device=freq_image.device) - v0 = torch.randint(0, size[1], (1,), device=freq_image.device) + u0 = torch.randint(0, U, (1,), device=freq_image.device) + v0 = torch.randint(0, V, (1,), device=freq_image.device) center = torch.tensor((u0, v0), device=freq_image.device) sigma = torch.rand(2, device=freq_image.device) * (std[1] - std[0]) + std[0] - g_kernel = self.gaussian_kernel((size[0], size[1]), sigma, center) - mask -= g_kernel + g_kernel = self.gaussian_kernel((U, V), sigma, center) + mask *= 1 - g_kernel.unsqueeze(0) filtered_freq_image = freq_image * mask - filtered_image = self.irfft2d_transform(filtered_freq_image).abs() - return filtered_image + return filtered_freq_image def __call__(self, freq_image: Tensor) -> Tensor: """Applies the Gaussian mixture mask transformation to the input frequency-domain image. @@ -99,9 +86,8 @@ def __call__(self, freq_image: Tensor) -> Tensor: freq_image: Tensor representing a frequency-domain image of shape (C, H, W//2+1). Returns: - Tensor: The transformed image after applying the Gaussian mixture mask. + Image tensor in frequency domain after applying the Gaussian mixture mask. """ - transformed_channel: Tensor = self.apply_gaussian_mixture_mask( + return self.apply_gaussian_mixture_mask( freq_image, self.num_gaussians, self.std_range ) - return transformed_channel diff --git a/tests/transforms/test_gaussian_mixture_masks.py b/tests/transforms/test_gaussian_mixture_masks.py index 40b51d457..db687000b 100644 --- a/tests/transforms/test_gaussian_mixture_masks.py +++ b/tests/transforms/test_gaussian_mixture_masks.py @@ -7,4 +7,4 @@ def test() -> None: transform = GaussianMixtureMask(20, (10, 15)) image = torch.rand(3, 32, 17) output = transform(image) - assert output.shape == (3, 32, 32) + assert output.shape == image.shape