diff --git a/lightly/transforms/__init__.py b/lightly/transforms/__init__.py index 338def43a..6794c146f 100644 --- a/lightly/transforms/__init__.py +++ b/lightly/transforms/__init__.py @@ -9,6 +9,7 @@ # All Rights Reserved from lightly.transforms.aim_transform import AIMTransform +from lightly.transforms.amplitude_rescale_transform import AmplitudeRescaleTranform from lightly.transforms.byol_transform import ( BYOLTransform, BYOLView1Transform, diff --git a/lightly/transforms/amplitude_rescale_transform.py b/lightly/transforms/amplitude_rescale_transform.py new file mode 100644 index 000000000..e09128ced --- /dev/null +++ b/lightly/transforms/amplitude_rescale_transform.py @@ -0,0 +1,34 @@ +from typing import Tuple + +import numpy as np +import torch +from torch import Tensor +from torch.distributions import Uniform + + +class AmplitudeRescaleTranform: + """Implementation of amplitude rescaling transformation. + + This transform will rescale the amplitude of the Fourier Spectrum (`freq_image`) of the image and return it. + + Attributes: + dist: + Uniform distribution in `[m, n)` from which the scaling value will be selected. + """ + + def __init__(self, range: Tuple[float, float] = (0.8, 1.75)) -> None: + self.dist = Uniform(range[0], range[1]) + + def __call__(self, freq_image: Tensor) -> Tensor: + amplitude = torch.sqrt(freq_image.real**2 + freq_image.imag**2) + + phase = torch.atan2(freq_image.imag, freq_image.real) + # p with shape (H, W) + p = self.dist.sample(freq_image.shape[1:]).to(freq_image.device) + # Unsqueeze to add channel dimension. + amplitude *= p.unsqueeze(0) + real = amplitude * torch.cos(phase) + imag = amplitude * torch.sin(phase) + output = torch.complex(real, imag) + + return output diff --git a/tests/transforms/test_amplitude_rescale_transform.py b/tests/transforms/test_amplitude_rescale_transform.py new file mode 100644 index 000000000..0b8d38eaa --- /dev/null +++ b/tests/transforms/test_amplitude_rescale_transform.py @@ -0,0 +1,26 @@ +import numpy as np +import torch + +from lightly.transforms import ( + AmplitudeRescaleTranform, + IRFFT2DTransform, + RFFT2DTransform, +) + + +# Testing function image -> FFT -> AmplitudeRescale. +# Compare shapes of source and result. +def test() -> None: + image = torch.randn(3, 64, 64) + + rfftTransform = RFFT2DTransform() + rfft = rfftTransform(image) + + ampRescaleTf_1 = AmplitudeRescaleTranform() + rescaled_rfft_1 = ampRescaleTf_1(rfft) + + ampRescaleTf_2 = AmplitudeRescaleTranform(range=(1.0, 2.0)) + rescaled_rfft_2 = ampRescaleTf_2(rfft) + + assert rescaled_rfft_1.shape == rfft.shape + assert rescaled_rfft_2.shape == rfft.shape