-
Notifications
You must be signed in to change notification settings - Fork 28
/
normalize_layer.cu
203 lines (180 loc) · 8.43 KB
/
normalize_layer.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
#include <algorithm>
#include <cfloat>
#include <vector>
#include "thrust/device_vector.h"
#include "caffe/layer.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/layers/normalize_layer.hpp"
namespace caffe {
// Channel-wise reduction: for every (sample, spatial position) pair, adds the
// values of `data` over all channels and writes the sum plus `epsilon` into
// norm_data[sample * spatial_dim + pos].
// `data` is indexed as (sample * channels + c) * spatial_dim + pos, so each
// output element walks the channel axis with stride spatial_dim.
// The loop covers num * spatial_dim output elements via Caffe's
// CUDA_KERNEL_LOOP; call sites in this file launch
// CAFFE_GET_BLOCKS(num * spatial_dim) blocks of CAFFE_CUDA_NUM_THREADS.
template <typename Dtype>
__global__ void kernel_channel_sum(const int num, const int channels, const int spatial_dim, Dtype epsilon,
    const Dtype* data, Dtype* norm_data) {
  CUDA_KERNEL_LOOP(out_idx, num * spatial_dim) {
    const int sample = out_idx / spatial_dim;
    const int pos = out_idx % spatial_dim;
    // Points at channel 0 of this (sample, pos) column.
    const Dtype* column = data + sample * channels * spatial_dim + pos;
    Dtype acc = Dtype(0);
    for (int ch = 0; ch < channels; ++ch) {
      acc += column[ch * spatial_dim];
    }
    norm_data[out_idx] = acc + epsilon;
  }
}
// Broadcasts a per-position factor across the channel axis:
//   output[n, c, s] = alpha * data[n, c, s] * norm_data[n, s] + beta * output[n, c, s]
// data/output_data are indexed as (n * channels + c) * spatial_dim + s and
// norm_data holds num * spatial_dim values (index n * spatial_dim + s).
// Follows the BLAS convention that output_data is NOT read when beta == 0.
// Without that, stale NaN/Inf values in the output buffer would survive a
// "beta = 0" overwrite because 0 * NaN == NaN.
template <typename Dtype>
__global__ void kernel_channel_scale(const int num, const int channels, const int spatial_dim,
    Dtype alpha, const Dtype* data, const Dtype* norm_data,
    Dtype beta, Dtype* output_data) {
  CUDA_KERNEL_LOOP(index, num * channels * spatial_dim) {
    const int n = index / channels / spatial_dim;
    const int s = index % spatial_dim;
    Dtype value = alpha * data[index] * norm_data[n * spatial_dim + s];
    // beta is a kernel argument, identical for every thread, so this branch
    // does not cause warp divergence.
    if (beta != Dtype(0)) {
      value += beta * output_data[index];
    }
    output_data[index] = value;
  }
}
// In-place broadcast multiply:
//   input_output_data[n, c, s] *= norm_data[n, s]
// input_output_data is indexed as (n * channels + c) * spatial_dim + s;
// norm_data holds num * spatial_dim values (index n * spatial_dim + s).
// Not launched anywhere in the visible part of this file (its only call site
// in Backward_gpu is commented out).
template <typename Dtype>
__global__ void kernel_channel_self_scale(const int num, const int channels, const int spatial_dim,
    const Dtype* norm_data, Dtype* input_output_data) {
  CUDA_KERNEL_LOOP(i, num * channels * spatial_dim) {
    const int sample = (i / spatial_dim) / channels;
    const int norm_idx = sample * spatial_dim + i % spatial_dim;
    input_output_data[i] = input_output_data[i] * norm_data[norm_idx];
  }
}
// Broadcasts a per-position divisor across the channel axis:
//   output[n, c, s] = alpha * data[n, c, s] / norm_data[n, s] + beta * output[n, c, s]
// data/output_data are indexed as (n * channels + c) * spatial_dim + s and
// norm_data holds num * spatial_dim values (index n * spatial_dim + s).
// Follows the BLAS convention that output_data is NOT read when beta == 0, so
// stale NaN/Inf values in the output buffer cannot leak in via 0 * NaN.
// The caller is responsible for keeping norm_data away from zero; Forward_gpu
// adds an epsilon.
template <typename Dtype>
__global__ void kernel_channel_div(const int num, const int channels, const int spatial_dim,
    Dtype alpha, const Dtype* data, const Dtype* norm_data,
    Dtype beta, Dtype* output_data) {
  CUDA_KERNEL_LOOP(index, num * channels * spatial_dim) {
    const int n = index / channels / spatial_dim;
    const int s = index % spatial_dim;
    Dtype value = alpha * data[index] / norm_data[n * spatial_dim + s];
    // beta is uniform across the launch, so this branch does not diverge.
    if (beta != Dtype(0)) {
      value += beta * output_data[index];
    }
    output_data[index] = value;
  }
}
// In-place broadcast divide:
//   input_output_data[n, c, s] /= norm_data[n, s]
// input_output_data is indexed as (n * channels + c) * spatial_dim + s;
// norm_data holds num * spatial_dim values (index n * spatial_dim + s).
template <typename Dtype>
__global__ void kernel_channel_self_div(const int num, const int channels, const int spatial_dim,
    const Dtype* norm_data, Dtype* input_output_data) {
  CUDA_KERNEL_LOOP(i, num * channels * spatial_dim) {
    const int sample = (i / spatial_dim) / channels;
    const int norm_idx = sample * spatial_dim + i % spatial_dim;
    input_output_data[i] = input_output_data[i] / norm_data[norm_idx];
  }
}
// Per-position inner product over the channel axis:
//   channel_dot[n, s] = sum_c data_1[n, c, s] * data_2[n, c, s]
// Both inputs are indexed as (n * channels + c) * spatial_dim + s;
// channel_dot receives num * spatial_dim values (index n * spatial_dim + s).
template <typename Dtype>
__global__ void kernel_channel_dot(const int num, const int channels,
    const int spatial_dim, const Dtype* data_1, const Dtype* data_2,
    Dtype* channel_dot) {
  CUDA_KERNEL_LOOP(out_idx, num * spatial_dim) {
    const int sample = out_idx / spatial_dim;
    const int pos = out_idx % spatial_dim;
    // Offset of channel 0 for this (sample, pos) column in both inputs.
    const int base = sample * channels * spatial_dim + pos;
    const Dtype* lhs = data_1 + base;
    const Dtype* rhs = data_2 + base;
    Dtype acc = Dtype(0);
    for (int ch = 0; ch < channels; ++ch) {
      const int off = ch * spatial_dim;
      acc += (lhs[off] * rhs[off]);
    }
    channel_dot[out_idx] = acc;
  }
}
// Element-wise sign: writes +1 for positive, -1 for negative and 0 otherwise
// (zero, negative zero and NaN all map to 0, since both comparisons are false).
template <typename Dtype>
__global__ void kernel_sign(const int count, const Dtype* input, Dtype* sign_out) {
  CUDA_KERNEL_LOOP(i, count) {
    const Dtype v = input[i];
    sign_out[i] = (v > Dtype(0)) ? Dtype(1) : ((v < Dtype(0)) ? Dtype(-1) : Dtype(0));
  }
}
// Forward pass: normalizes bottom[0] across the channel axis, independently
// for every (sample, spatial position):
//   L2: top = x / sqrt(sum_c x^2 + 1e-12)
//   L1: top = x / (sum_c |x| + 1e-6)
// The per-position norm (num * spatial_dim values) is written to top[1] when
// the layer has two tops, otherwise to the internal norm_ blob.
// squared_ is used as scratch space for x^2 (L2) or |x| (L1).
// Any other normalize_type_ hits NOT_IMPLEMENTED.
template <typename Dtype>
void NormalizeLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  // Added to the channel sum so an all-zero column does not divide by zero.
  const Dtype kL2Epsilon = Dtype(1e-12);
  const Dtype kL1Epsilon = Dtype(1e-6);
  const Dtype* bottom_data = bottom[0]->gpu_data();
  Dtype* top_data = top[0]->mutable_gpu_data();
  Dtype* square_data = squared_.mutable_gpu_data();
  Dtype* norm_data = (top.size() == 2) ? top[1]->mutable_gpu_data() : norm_.mutable_gpu_data();
  const int num = bottom[0]->num();
  const int channels = bottom[0]->channels();
  const int spatial_dim = bottom[0]->height() * bottom[0]->width();
  const int count = num * channels * spatial_dim;  // elements in bottom/top
  const int norm_count = num * spatial_dim;        // elements in the norm blob
  if (normalize_type_ == "L2") {
    caffe_gpu_powx(count, bottom_data, Dtype(2), square_data);
    // norm = sum_c x^2 + eps, then sqrt in place.
    // NOLINT_NEXT_LINE(whitespace/operators)
    kernel_channel_sum<Dtype><<<CAFFE_GET_BLOCKS(norm_count), CAFFE_CUDA_NUM_THREADS>>>(
        num, channels, spatial_dim, kL2Epsilon, square_data, norm_data);
    CUDA_POST_KERNEL_CHECK;
    caffe_gpu_powx(norm_count, norm_data, Dtype(0.5), norm_data);
    // top = x / norm (beta = 0: top is overwritten).
    // NOLINT_NEXT_LINE(whitespace/operators)
    kernel_channel_div<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
        num, channels, spatial_dim, Dtype(1), bottom_data, norm_data, Dtype(0), top_data);
    CUDA_POST_KERNEL_CHECK;
  } else if (normalize_type_ == "L1") {
    caffe_gpu_abs(count, bottom_data, square_data);
    // norm = sum_c |x| + eps.
    // NOLINT_NEXT_LINE(whitespace/operators)
    kernel_channel_sum<Dtype><<<CAFFE_GET_BLOCKS(norm_count), CAFFE_CUDA_NUM_THREADS>>>(
        num, channels, spatial_dim, kL1Epsilon, square_data, norm_data);
    CUDA_POST_KERNEL_CHECK;
    // top = x / norm (beta = 0: top is overwritten).
    // NOLINT_NEXT_LINE(whitespace/operators)
    kernel_channel_div<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
        num, channels, spatial_dim, Dtype(1), bottom_data, norm_data, Dtype(0), top_data);
    CUDA_POST_KERNEL_CHECK;
  } else {
    NOT_IMPLEMENTED;
  }
}
// Backward pass for y = x / norm, computed per (sample, spatial position).
//
// Gradient through top[0] (only when propagate_down[0]):
//   L2: dx = (dy - y * <y, dy>) / norm
//   L1: dx = (dy - sign(x) * <y, dy>) / norm
// where <y, dy> is the channel-wise dot product computed into norm_.diff as
// scratch. When fix_gradient_ is set, the final 1/norm factor is skipped.
// NOTE(review): presumably this keeps the gradient scale independent of the
// input norm; confirm the intent.
//
// Gradient through top[1] (the norm output, only when bp_norm_):
//   L2: dx += dnorm * y        (d norm / dx = x / norm)
//   L1: dx += dnorm * sign(x)  (d norm / dx = sign(x))
// If the top[0] gradient was not computed in this pass, the norm term
// overwrites bottom_diff instead of accumulating into stale data.
//
// squared_ is reused as scratch for sign(x) in the L1 case.
template <typename Dtype>
void NormalizeLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  const Dtype* top_diff = top[0]->gpu_diff();
  const Dtype* top_data = top[0]->gpu_data();
  const Dtype* bottom_data = bottom[0]->gpu_data();
  Dtype* square_data = squared_.mutable_gpu_data();
  const Dtype* norm_data = (top.size() == 2) ? top[1]->gpu_data() : norm_.gpu_data();
  Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
  Dtype* temp_diff = norm_.mutable_gpu_diff();
  const int num = top[0]->num();
  const int channels = top[0]->channels();
  const int spatial_dim = bottom[0]->height() * bottom[0]->width();
  const int count = num * channels * spatial_dim;  // elements in bottom/top
  const int norm_count = num * spatial_dim;        // elements in the norm blob
  if (propagate_down[0]) {
    // temp = <y, dy> per (sample, position).
    // NOLINT_NEXT_LINE(whitespace/operators)
    kernel_channel_dot<Dtype><<<CAFFE_GET_BLOCKS(norm_count), CAFFE_CUDA_NUM_THREADS>>>(
        num, channels, spatial_dim, top_data, top_diff, temp_diff);
    CUDA_POST_KERNEL_CHECK;
    if (normalize_type_ == "L2") {
      // bottom_diff = y * temp
      // NOLINT_NEXT_LINE(whitespace/operators)
      kernel_channel_scale<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
          num, channels, spatial_dim, Dtype(1), top_data, temp_diff, Dtype(0), bottom_diff);
      CUDA_POST_KERNEL_CHECK;
    } else if (normalize_type_ == "L1") {
      // square_data = sign(x); bottom_diff = sign(x) * temp
      // NOLINT_NEXT_LINE(whitespace/operators)
      kernel_sign<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
          count, bottom_data, square_data);
      CUDA_POST_KERNEL_CHECK;
      // NOLINT_NEXT_LINE(whitespace/operators)
      kernel_channel_scale<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
          num, channels, spatial_dim, Dtype(1), square_data, temp_diff, Dtype(0), bottom_diff);
      CUDA_POST_KERNEL_CHECK;
    } else {
      NOT_IMPLEMENTED;
    }
    // bottom_diff = dy - bottom_diff
    caffe_gpu_sub(count, top_diff, bottom_diff, bottom_diff);
    if (!fix_gradient_) {
      // bottom_diff /= norm
      // NOLINT_NEXT_LINE(whitespace/operators)
      kernel_channel_self_div<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
          num, channels, spatial_dim, norm_data, bottom_diff);
      CUDA_POST_KERNEL_CHECK;
    }
  }
  if (bp_norm_) {
    const Dtype* norm_diff = top[1]->gpu_diff();
    // Accumulate onto the top[0] gradient only if it was written above.
    // Otherwise bottom_diff still holds data from a previous pass and must be
    // overwritten.
    const Dtype beta = propagate_down[0] ? Dtype(1) : Dtype(0);
    if (normalize_type_ == "L2") {
      // NOLINT_NEXT_LINE(whitespace/operators)
      kernel_channel_scale<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
          num, channels, spatial_dim, Dtype(1), top_data, norm_diff, beta, bottom_diff);
      CUDA_POST_KERNEL_CHECK;
    } else if (normalize_type_ == "L1") {
      if (!propagate_down[0]) {
        // sign(x) is only already in square_data if the L1 branch above ran.
        // NOLINT_NEXT_LINE(whitespace/operators)
        kernel_sign<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
            count, bottom_data, square_data);
        CUDA_POST_KERNEL_CHECK;
      }
      // NOLINT_NEXT_LINE(whitespace/operators)
      kernel_channel_scale<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
          num, channels, spatial_dim, Dtype(1), square_data, norm_diff, beta, bottom_diff);
      CUDA_POST_KERNEL_CHECK;
    }
  }
}
INSTANTIATE_LAYER_GPU_FUNCS(NormalizeLayer);
} // namespace caffe