forked from irhum/R2Plus1D-PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
module.py
65 lines (53 loc) · 3.21 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import math
import torch.nn as nn
from torch.nn.modules.utils import _triple
class SpatioTemporalConv(nn.Module):
    r"""Factored (2+1)D replacement for a 3D convolution.

    The input is first convolved over its two spatial axes only (a
    ``[1, kh, kw]`` kernel) into an intermediate channel space, passed
    through batch normalization and ReLU, and then convolved over the time
    axis only (a ``[kt, 1, 1]`` kernel) to produce the final output.

    No batch norm or ReLU is applied after the temporal convolution; that is
    left to the caller so that, from the outside, this module takes the same
    constructor arguments and produces the same output shape as a plain
    ``nn.Conv3d`` would.

    Args:
        in_channels (int): Number of channels in the input tensor
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel, ordered
            ``(time, height, width)``. An int is used for all three axes.
        stride (int or tuple, optional): Stride of the convolution, same
            ordering as ``kernel_size``. Default: 1
        padding (int or tuple, optional): Zero-padding added to the sides of
            the input during their respective convolutions, same ordering as
            ``kernel_size``. Default: 0
        bias (bool, optional): If ``True``, adds a learnable bias to both the
            spatial and the temporal convolution. Default: ``True``
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super(SpatioTemporalConv, self).__init__()

        # if ints are entered, convert them to iterables, 1 -> (1, 1, 1)
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)

        # Decompose each parameter into a spatial and a temporal component by
        # masking the axis that the respective convolution must not touch
        # with the neutral value (kernel 1, stride 1, padding 0). Without
        # this, stride and padding would be applied twice, once per conv.
        spatial_kernel_size = [1, kernel_size[1], kernel_size[2]]
        spatial_stride = [1, stride[1], stride[2]]
        spatial_padding = [0, padding[1], padding[2]]

        temporal_kernel_size = [kernel_size[0], 1, 1]
        temporal_stride = [stride[0], 1, 1]
        temporal_padding = [padding[0], 0, 0]

        # Number of intermediary channels (M), using the formula from the
        # paper, section 3.5:
        #   M = floor(t * h * w * N_in * N_out / (h * w * N_in + t * N_out))
        # Integer floor division keeps the result exact (no float rounding).
        spatial_area = kernel_size[1] * kernel_size[2]
        full_conv_weights = kernel_size[0] * spatial_area * in_channels * out_channels
        factored_weights_per_channel = spatial_area * in_channels + kernel_size[0] * out_channels
        intermed_channels = int(full_conv_weights // factored_weights_per_channel)

        # For small kernels combined with a single input or output channel
        # (e.g. a 1x1x1 kernel with in_channels=1) the formula floors to 0,
        # which would create convolutions with zero channels. Keep at least
        # one intermediate channel so the module stays usable.
        intermed_channels = max(1, intermed_channels)

        # the spatial conv is effectively a 2D conv due to the
        # spatial_kernel_size, followed by batch_norm and ReLU
        self.spatial_conv = nn.Conv3d(in_channels, intermed_channels, spatial_kernel_size,
                                      stride=spatial_stride, padding=spatial_padding, bias=bias)
        self.bn = nn.BatchNorm3d(intermed_channels)
        self.relu = nn.ReLU()

        # the temporal conv is effectively a 1D conv, but has batch norm
        # and ReLU added inside the model constructor, not here. This is an
        # intentional design choice, to allow this module to externally act
        # identical to a standard Conv3D, so it can be reused easily in any
        # other codebase
        self.temporal_conv = nn.Conv3d(intermed_channels, out_channels, temporal_kernel_size,
                                       stride=temporal_stride, padding=temporal_padding, bias=bias)

    def forward(self, x):
        """Apply spatial conv -> batch norm -> ReLU -> temporal conv.

        Args:
            x: Input tensor in the layout ``nn.Conv3d`` accepts, with the
                time axis before the two spatial axes, i.e.
                ``(batch, in_channels, time, height, width)``.

        Returns:
            The output of the temporal convolution, with ``out_channels``
            channels. No normalization or activation is applied to it here.
        """
        x = self.relu(self.bn(self.spatial_conv(x)))
        x = self.temporal_conv(x)
        return x