From 2c9e7c8c79798affef146fd408176b277c186e8f Mon Sep 17 00:00:00 2001 From: Snehil Chatterjee Date: Mon, 14 Oct 2024 19:37:09 +0530 Subject: [PATCH] removed rfft2d init from GMM --- lightly/transforms/gaussian_mixture_masks_transform.py | 4 ---- tests/transforms/test_gaussian_mixture_masks.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/lightly/transforms/gaussian_mixture_masks_transform.py b/lightly/transforms/gaussian_mixture_masks_transform.py index c9804db3b..b9cc2bb6e 100644 --- a/lightly/transforms/gaussian_mixture_masks_transform.py +++ b/lightly/transforms/gaussian_mixture_masks_transform.py @@ -5,7 +5,6 @@ from torch import Tensor from lightly.transforms.irfft2d_transform import IRFFT2DTransform -from lightly.transforms.rfft2d_transform import RFFT2DTransform class GaussianMixtureMask: @@ -29,9 +28,6 @@ def __init__( num_gaussians: Number of Gaussian kernels to generate in the mixture mask. std_range: Tuple containing the minimum and maximum standard deviation for the Gaussians. """ - - self.rfft2d_transform = RFFT2DTransform() - self.num_gaussians = num_gaussians self.std_range = std_range diff --git a/tests/transforms/test_gaussian_mixture_masks.py b/tests/transforms/test_gaussian_mixture_masks.py index 8f2c8678b..40b51d457 100644 --- a/tests/transforms/test_gaussian_mixture_masks.py +++ b/tests/transforms/test_gaussian_mixture_masks.py @@ -5,6 +5,6 @@ def test() -> None: transform = GaussianMixtureMask(20, (10, 15)) - image = torch.rand(3, 32, 32) + image = torch.rand(3, 32, 17) output = transform(image) assert output.shape == (3, 32, 32)