-
Notifications
You must be signed in to change notification settings - Fork 0
/
bn.py
215 lines (180 loc) · 7.54 KB
/
bn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
from collections import OrderedDict
from itertools import repeat

try:
    # python 3.3+ (the ABC aliases were removed from `collections` in 3.10)
    from collections.abc import Iterable
except ImportError:
    # python 2
    from collections import Iterable

try:
    # python 3
    from queue import Queue
except ImportError:
    # python 2
    from Queue import Queue

import torch
import torch.nn as nn
import torch.autograd as autograd

from .functions import inplace_abn, inplace_abn_sync
def _pair(x):
if isinstance(x, Iterable):
return x
return tuple(repeat(x, 2))
class ABN(nn.Sequential):
    """Activated Batch Normalization

    This gathers a `BatchNorm2d` and an activation function in a single module.
    The two stages are registered as the children `bn` and `act` and run in
    that order.
    """

    def __init__(self, num_features, activation=None, **kwargs):
        """Creates an Activated Batch Normalization module

        Parameters
        ----------
        num_features : int
            Number of feature channels in the input and output.
        activation : nn.Module or None
            Module used as an activation function. When `None` (the default), a
            new `nn.ReLU(inplace=True)` is created for this instance.
        kwargs
            All other arguments are forwarded to the `BatchNorm2d` constructor.
        """
        # Build the default activation per instance. A module used directly as a
        # default argument is created once, when the function is defined, and
        # would be the *same* object (with the same hooks/attributes) inside
        # every ABN constructed without an explicit `activation`.
        if activation is None:
            activation = nn.ReLU(inplace=True)
        super(ABN, self).__init__(OrderedDict([
            ("bn", nn.BatchNorm2d(num_features, **kwargs)),
            ("act", activation),
        ]))
class InPlaceABN(nn.Module):
    """InPlace Activated Batch Normalization"""

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
        """Creates an InPlace Activated Batch Normalization module

        Parameters
        ----------
        num_features : int
            Number of feature channels in the input and output.
        eps : float
            Small constant to prevent numerical issues.
        momentum : float
            Momentum factor applied to compute running statistics.
        affine : bool
            If `True` apply learned scale and shift transformation after normalization.
        activation : str
            Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
        slope : float
            Negative slope for the `leaky_relu` activation.
        """
        super(InPlaceABN, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        self.activation = activation
        self.slope = slope
        if self.affine:
            # Allocated uninitialised; reset_parameters() below gives them values.
            self.weight = nn.Parameter(torch.Tensor(num_features))
            self.bias = nn.Parameter(torch.Tensor(num_features))
        else:
            # Register as None so `self.weight` / `self.bias` always exist.
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset running mean to 0 and running variance to 1; when affine, weight to 1 and bias to 0."""
        self.running_mean.zero_()
        self.running_var.fill_(1)
        if self.affine:
            self.weight.data.fill_(1)
            self.bias.data.zero_()

    def forward(self, x):
        """Normalize and activate `x` by delegating to `inplace_abn` (from `.functions`).

        The running-statistics buffers, the current `self.training` flag and
        the hyper-parameters are all forwarded. Whether the buffers are
        updated, and whether `x` itself is overwritten, is decided inside
        `inplace_abn` (not visible here; the name suggests in-place — confirm
        before reusing `x` after the call).
        """
        return inplace_abn(x, self.weight, self.bias, autograd.Variable(self.running_mean),
                           autograd.Variable(self.running_var), self.training, self.momentum, self.eps,
                           self.activation, self.slope)

    def __repr__(self):
        rep = ('{name}({num_features}, eps={eps}, momentum={momentum},'
               ' affine={affine}, activation={activation}')
        # `slope` only applies to leaky_relu, so it is shown only in that case.
        if self.activation == "leaky_relu":
            rep += ', slope={slope})'
        else:
            rep += ')'
        return rep.format(name=self.__class__.__name__, num_features=self.num_features,
                          eps=self.eps, momentum=self.momentum, affine=self.affine,
                          activation=self.activation, slope=self.slope)
class InPlaceABNSync(nn.Module):
    """InPlace Activated Batch Normalization with cross-GPU synchronization

    This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DataParallel`.
    """

    def __init__(self, num_features, devices=None, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu",
                 slope=0.01):
        """Creates a synchronized, InPlace Activated Batch Normalization module

        Parameters
        ----------
        num_features : int
            Number of feature channels in the input and output.
        devices : list of int or None
            IDs of the GPUs that will run the replicas of this module. When
            falsy, all GPUs reported by `torch.cuda.device_count()` are used.
            The first entry acts as the master; the rest are workers.
        eps : float
            Small constant to prevent numerical issues.
        momentum : float
            Momentum factor applied to compute running statistics.
        affine : bool
            If `True` apply learned scale and shift transformation after normalization.
        activation : str
            Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
        slope : float
            Negative slope for the `leaky_relu` activation.
        """
        super(InPlaceABNSync, self).__init__()
        self.num_features = num_features
        self.devices = devices if devices else list(range(torch.cuda.device_count()))
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        self.activation = activation
        self.slope = slope
        if self.affine:
            # Allocated uninitialised; reset_parameters() below gives them values.
            self.weight = nn.Parameter(torch.Tensor(num_features))
            self.bias = nn.Parameter(torch.Tensor(num_features))
        else:
            # Register as None so `self.weight` / `self.bias` always exist.
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

        # Initialize queues: one shared master queue sized to the number of
        # workers, plus one single-slot queue per worker. They are handed to
        # `inplace_abn_sync`, presumably to exchange statistics between the
        # replicas (see `.functions` — TODO confirm).
        self.worker_ids = self.devices[1:]
        self.master_queue = Queue(len(self.worker_ids))
        self.worker_queues = [Queue(1) for _ in self.worker_ids]

    def reset_parameters(self):
        """Reset running mean to 0 and running variance to 1; when affine, weight to 1 and bias to 0."""
        self.running_mean.zero_()
        self.running_var.fill_(1)
        if self.affine:
            self.weight.data.fill_(1)
            self.bias.data.zero_()

    def forward(self, x):
        """Normalize and activate `x` via `inplace_abn_sync` (from `.functions`).

        The replica whose input lives on `self.devices[0]` runs in master mode
        and receives every worker queue. Any other replica runs in worker mode
        and receives only the queue matching its device. A device that is not
        in `self.devices[1:]` raises `ValueError` from `list.index`.
        """
        if x.get_device() == self.devices[0]:
            # Master mode
            extra = {
                "is_master": True,
                "master_queue": self.master_queue,
                "worker_queues": self.worker_queues,
                "worker_ids": self.worker_ids
            }
        else:
            # Worker mode
            extra = {
                "is_master": False,
                "master_queue": self.master_queue,
                "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())]
            }
        return inplace_abn_sync(x, self.weight, self.bias, autograd.Variable(self.running_mean),
                                autograd.Variable(self.running_var), extra, self.training, self.momentum, self.eps,
                                self.activation, self.slope)

    def __repr__(self):
        rep = ('{name}({num_features}, eps={eps}, momentum={momentum},'
               ' affine={affine}, devices={devices}, activation={activation}')
        # `slope` only applies to leaky_relu, so it is shown only in that case.
        if self.activation == "leaky_relu":
            rep += ', slope={slope})'
        else:
            rep += ')'
        return rep.format(name=self.__class__.__name__, num_features=self.num_features,
                          eps=self.eps, momentum=self.momentum, affine=self.affine,
                          devices=self.devices, activation=self.activation, slope=self.slope)
class InPlaceABNWrapper(nn.Module):
    """Wrapper module to make `InPlaceABN` compatible with `ABN`

    Every constructor argument is passed straight to `InPlaceABN`. The
    resulting module is registered under the child name `bn`, and the forward
    pass simply delegates to it.
    """

    def __init__(self, *args, **kwargs):
        super(InPlaceABNWrapper, self).__init__()
        inner = InPlaceABN(*args, **kwargs)
        self.bn = inner

    def forward(self, input):
        layer = self.bn
        return layer(input)
class InPlaceABNSyncWrapper(nn.Module):
    """Wrapper module to make `InPlaceABNSync` compatible with `ABN`

    Every constructor argument is passed straight to `InPlaceABNSync`. The
    resulting module is registered under the child name `bn`, and the forward
    pass simply delegates to it.
    """

    def __init__(self, *args, **kwargs):
        super(InPlaceABNSyncWrapper, self).__init__()
        inner = InPlaceABNSync(*args, **kwargs)
        self.bn = inner

    def forward(self, input):
        layer = self.bn
        return layer(input)