Skip to content

Commit

Permalink
GMM changed output to frequency domain output
Browse files Browse the repository at this point in the history
  • Loading branch information
snehilchatterjee committed Oct 15, 2024
1 parent 2c9e7c8 commit b68b473
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 27 deletions.
38 changes: 12 additions & 26 deletions lightly/transforms/gaussian_mixture_masks_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@
import torch.fft
from torch import Tensor

from lightly.transforms.irfft2d_transform import IRFFT2DTransform


class GaussianMixtureMask:
"""Applies a Gaussian Mixture Mask in the Fourier domain to a single-channel image.
"""Applies a Gaussian Mixture Mask in the Fourier domain to an image.
The mask is created using random Gaussian kernels, which are applied in
the frequency domain via RFFT2D, and then the IRFFT2D is used to return
to the spatial domain. The transformation is applied to each image channel separately.
the frequency domain.
Attributes:
num_gaussians: Number of Gaussian kernels to generate in the mixture mask.
Expand Down Expand Up @@ -42,7 +39,7 @@ def gaussian_kernel(
center: Tensor specifying the center of the Gaussian kernel.
Returns:
Tensor: A 2D Gaussian kernel.
A 2D Gaussian kernel tensor.
"""
u, v = torch.meshgrid(torch.arange(0, size[0]), torch.arange(0, size[1]))
u = u.to(sigma.device)
Expand All @@ -65,32 +62,22 @@ def apply_gaussian_mixture_mask(
std: Tuple specifying the standard deviation range for the Gaussians.
Returns:
Tensor: Image after applying the Gaussian mixture mask.
Image tensor in frequency domain after applying the Gaussian mixture mask.
"""
image_size = freq_image.shape[1:]
original_height = image_size[0]
original_width = 2 * (image_size[1] - 1)

original_shape = (original_height, original_width)

self.irfft2d_transform = IRFFT2DTransform(original_shape)

size = freq_image[0].shape

(C, U, V) = freq_image.shape
mask = freq_image.new_ones(freq_image.shape)

for _ in range(num_gaussians):
u0 = torch.randint(0, size[0], (1,), device=freq_image.device)
v0 = torch.randint(0, size[1], (1,), device=freq_image.device)
u0 = torch.randint(0, U, (1,), device=freq_image.device)
v0 = torch.randint(0, V, (1,), device=freq_image.device)
center = torch.tensor((u0, v0), device=freq_image.device)
sigma = torch.rand(2, device=freq_image.device) * (std[1] - std[0]) + std[0]

g_kernel = self.gaussian_kernel((size[0], size[1]), sigma, center)
mask -= g_kernel
g_kernel = self.gaussian_kernel((U, V), sigma, center)
mask *= 1 - g_kernel.unsqueeze(0)

filtered_freq_image = freq_image * mask
filtered_image = self.irfft2d_transform(filtered_freq_image).abs()
return filtered_image
return filtered_freq_image

def __call__(self, freq_image: Tensor) -> Tensor:
"""Applies the Gaussian mixture mask transformation to the input frequency-domain image.
Expand All @@ -99,9 +86,8 @@ def __call__(self, freq_image: Tensor) -> Tensor:
freq_image: Tensor representing a frequency-domain image of shape (C, H, W//2+1).
Returns:
Tensor: The transformed image after applying the Gaussian mixture mask.
Image tensor in frequency domain after applying the Gaussian mixture mask.
"""
transformed_channel: Tensor = self.apply_gaussian_mixture_mask(
return self.apply_gaussian_mixture_mask(
freq_image, self.num_gaussians, self.std_range
)
return transformed_channel
2 changes: 1 addition & 1 deletion tests/transforms/test_gaussian_mixture_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ def test() -> None:
transform = GaussianMixtureMask(20, (10, 15))
image = torch.rand(3, 32, 17)
output = transform(image)
assert output.shape == (3, 32, 32)
assert output.shape == image.shape

0 comments on commit b68b473

Please sign in to comment.