-
Notifications
You must be signed in to change notification settings - Fork 1
/
specpool2d.py
121 lines (104 loc) · 4.85 KB
/
specpool2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
import torch.nn as nn
from torch.autograd import Function
import math
from torch.nn.modules.utils import _pair
def _spectral_crop(input, oheight, owidth):
cutoff_freq_h = math.ceil(oheight / 2)
cutoff_freq_w = math.ceil(owidth / 2)
if oheight % 2 == 1:
if owidth % 2 == 1:
top_left = input[:, :, :cutoff_freq_h, :cutoff_freq_w]
top_right = input[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):]
bottom_left = input[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w]
bottom_right = input[:, :, -(cutoff_freq_h-1):, -(cutoff_freq_w-1):]
else:
top_left = input[:, :, :cutoff_freq_h, :cutoff_freq_w]
top_right = input[:, :, :cutoff_freq_h, -cutoff_freq_w:]
bottom_left = input[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w]
bottom_right = input[:, :, -(cutoff_freq_h-1):, -cutoff_freq_w:]
else:
if owidth % 2 == 1:
top_left = input[:, :, :cutoff_freq_h, :cutoff_freq_w]
top_right = input[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):]
bottom_left = input[:, :, -cutoff_freq_h:, :cutoff_freq_w]
bottom_right = input[:, :, -cutoff_freq_h:, -(cutoff_freq_w-1):]
else:
top_left = input[:, :, :cutoff_freq_h, :cutoff_freq_w]
top_right = input[:, :, :cutoff_freq_h, -cutoff_freq_w:]
bottom_left = input[:, :, -cutoff_freq_h:, :cutoff_freq_w]
bottom_right = input[:, :, -cutoff_freq_h:, -cutoff_freq_w:]
top_combined = torch.cat((top_left, top_right), dim=-1)
bottom_combined = torch.cat((bottom_left, bottom_right), dim=-1)
all_together = torch.cat((top_combined, bottom_combined), dim=-2)
return all_together
def _spectral_pad(input, output, oheight, owidth):
cutoff_freq_h = math.ceil(oheight / 2)
cutoff_freq_w = math.ceil(owidth / 2)
pad = torch.zeros_like(input)
if oheight % 2 == 1:
if owidth % 2 == 1:
pad[:, :, :cutoff_freq_h, :cutoff_freq_w] = output[:, :, :cutoff_freq_h, :cutoff_freq_w]
pad[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):] = output[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):]
pad[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w] = output[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w]
pad[:, :, -(cutoff_freq_h-1):, -(cutoff_freq_w-1):] = output[:, :, -(cutoff_freq_h-1):, -(cutoff_freq_w-1):]
else:
pad[:, :, :cutoff_freq_h, :cutoff_freq_w] = output[:, :, :cutoff_freq_h, :cutoff_freq_w]
pad[:, :, :cutoff_freq_h, -cutoff_freq_w:] = output[:, :, :cutoff_freq_h, -cutoff_freq_w:]
pad[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w] = output[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w]
pad[:, :, -(cutoff_freq_h-1):, -cutoff_freq_w:] = output[:, :, -(cutoff_freq_h-1):, -cutoff_freq_w:]
else:
if owidth % 2 == 1:
pad[:, :, :cutoff_freq_h, :cutoff_freq_w] = output[:, :, :cutoff_freq_h, :cutoff_freq_w]
pad[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):] = output[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):]
pad[:, :, -cutoff_freq_h:, :cutoff_freq_w] = output[:, :, -cutoff_freq_h:, :cutoff_freq_w]
pad[:, :, -cutoff_freq_h:, -(cutoff_freq_w-1):] = output[:, :, -cutoff_freq_h:, -(cutoff_freq_w-1):]
else:
pad[:, :, :cutoff_freq_h, :cutoff_freq_w] = output[:, :, :cutoff_freq_h, :cutoff_freq_w]
pad[:, :, :cutoff_freq_h, -cutoff_freq_w:] = output[:, :, :cutoff_freq_h, -cutoff_freq_w:]
pad[:, :, -cutoff_freq_h:, :cutoff_freq_w] = output[:, :, -cutoff_freq_h:, :cutoff_freq_w]
pad[:, :, -cutoff_freq_h:, -cutoff_freq_w:] = output[:, :, -cutoff_freq_h:, -cutoff_freq_w:]
return pad
def DiscreteHartleyTransform(input):
    """Compute the 2-D discrete Hartley transform over the last two dims.

    The transform is obtained from the orthonormal 2-D FFT as
    ``real(F) - imag(F)``.  With ``norm='ortho'`` applying the function
    twice returns the original signal, which is why callers in this file
    use it for both the forward and the inverse transform.

    Args:
        input: Tensor with at least two dimensions; leading dimensions are
            treated as batch dimensions.

    Returns:
        Real tensor with the same shape as ``input``.
    """
    spectrum = torch.fft.fft2(input, dim=(-2, -1), norm='ortho')
    # Subtract the parts directly; no need to stack them into an extra
    # trailing axis first, and this works for any number of batch dims.
    return spectrum.real - spectrum.imag
class SpectralPoolingFunction(Function):
    """Autograd function that pools by cropping in the Hartley domain.

    Forward: Hartley transform -> crop to ``oheight`` x ``owidth`` ->
    Hartley transform again (used here as the inverse).
    Backward: Hartley transform of the incoming gradient -> pad back to the
    input size -> Hartley transform again.
    """

    @staticmethod
    def forward(ctx, input, oheight, owidth):
        """Return ``input`` resampled to ``oheight`` x ``owidth``.

        Args:
            ctx: Autograd context; stores the target size and the input.
            input: Tensor to pool.
            oheight: Target height.
            owidth: Target width.
        """
        ctx.oh = oheight
        ctx.ow = owidth
        # The input is kept so backward can size the padded gradient.
        ctx.save_for_backward(input)

        # Hartley transform (computed through the FFT)
        spectrum = DiscreteHartleyTransform(input)
        # frequency cropping
        cropped = _spectral_crop(spectrum, oheight, owidth)
        # inverse Hartley transform
        return DiscreteHartleyTransform(cropped)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``input``; the size args get ``None``."""
        # ``saved_tensors`` is the supported accessor; ``saved_variables``
        # is deprecated.
        input, = ctx.saved_tensors

        # Hartley transform (computed through the FFT)
        grad_spectrum = DiscreteHartleyTransform(grad_output)
        # frequency padding
        grad_input = _spectral_pad(input, grad_spectrum, ctx.oh, ctx.ow)
        # inverse Hartley transform
        grad_input = DiscreteHartleyTransform(grad_input)
        return grad_input, None, None
class SpectralPool2d(nn.Module):
    """Module wrapper around ``SpectralPoolingFunction``.

    Args:
        scale_factor: Single number or (height, width) pair; each spatial
            size is multiplied by its factor and rounded up to obtain the
            output size.
    """

    def __init__(self, scale_factor):
        super().__init__()
        # Normalised through ``_pair`` so one number applies to both axes.
        self.scale_factor = _pair(scale_factor)

    def forward(self, input):
        """Pool the last two dimensions of ``input`` to the scaled size."""
        in_height = input.size(-2)
        in_width = input.size(-1)
        out_height = math.ceil(in_height * self.scale_factor[0])
        out_width = math.ceil(in_width * self.scale_factor[1])
        return SpectralPoolingFunction.apply(input, out_height, out_width)
if __name__ == '__main__':
    # Smoke run: pool a random (4, 1, 100, 64) batch, shrinking the height
    # by a factor of 0.1 and leaving the width unchanged.
    layer = SpectralPool2d(scale_factor=(0.1, 1))
    sample = torch.randn(4, 1, 100, 64)
    pooled = layer(sample)