Skip to content

Commit

Permalink
Add AmplitudeRescaleTransform (#1694)
Browse files Browse the repository at this point in the history
  • Loading branch information
payo101 authored Oct 21, 2024
1 parent 5ac3898 commit 9578268
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 0 deletions.
1 change: 1 addition & 0 deletions lightly/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# All Rights Reserved

from lightly.transforms.aim_transform import AIMTransform
from lightly.transforms.amplitude_rescale_transform import AmplitudeRescaleTranform
from lightly.transforms.byol_transform import (
BYOLTransform,
BYOLView1Transform,
Expand Down
34 changes: 34 additions & 0 deletions lightly/transforms/amplitude_rescale_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Tuple

import numpy as np
import torch
from torch import Tensor
from torch.distributions import Uniform


class AmplitudeRescaleTranform:
"""Implementation of amplitude rescaling transformation.
This transform will rescale the amplitude of the Fourier Spectrum (`freq_image`) of the image and return it.
Attributes:
dist:
Uniform distribution in `[m, n)` from which the scaling value will be selected.
"""

def __init__(self, range: Tuple[float, float] = (0.8, 1.75)) -> None:
self.dist = Uniform(range[0], range[1])

def __call__(self, freq_image: Tensor) -> Tensor:
amplitude = torch.sqrt(freq_image.real**2 + freq_image.imag**2)

phase = torch.atan2(freq_image.imag, freq_image.real)
# p with shape (H, W)
p = self.dist.sample(freq_image.shape[1:]).to(freq_image.device)
# Unsqueeze to add channel dimension.
amplitude *= p.unsqueeze(0)
real = amplitude * torch.cos(phase)
imag = amplitude * torch.sin(phase)
output = torch.complex(real, imag)

return output
26 changes: 26 additions & 0 deletions tests/transforms/test_amplitude_rescale_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import numpy as np
import torch

from lightly.transforms import (
AmplitudeRescaleTranform,
IRFFT2DTransform,
RFFT2DTransform,
)


# Testing function image -> FFT -> AmplitudeRescale.
# Compare shapes of source and result.
def test() -> None:
image = torch.randn(3, 64, 64)

rfftTransform = RFFT2DTransform()
rfft = rfftTransform(image)

ampRescaleTf_1 = AmplitudeRescaleTranform()
rescaled_rfft_1 = ampRescaleTf_1(rfft)

ampRescaleTf_2 = AmplitudeRescaleTranform(range=(1.0, 2.0))
rescaled_rfft_2 = ampRescaleTf_2(rfft)

assert rescaled_rfft_1.shape == rfft.shape
assert rescaled_rfft_2.shape == rfft.shape

0 comments on commit 9578268

Please sign in to comment.