import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
import pdb


# https://github.com/keitakurita/Better_LSTM_PyTorch/blob/master/better_lstm/model.py
class VariationalDropout(nn.Module):
    """
    Applies the same dropout mask across the temporal dimension.
    See https://arxiv.org/abs/1512.05287 for more details.

    Note that, unlike the paper above, this is not applied to the recurrent
    activations inside the LSTM; it is applied to the inputs and outputs of
    the recurrent layer.
    """

    def __init__(self, dropout, batch_first):
        super().__init__()
        self.dropout = dropout
        self.batch_first = batch_first

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.dropout <= 0.:
            return x

        is_packed = isinstance(x, rnn.PackedSequence)
        if is_packed:
            x, lengths = rnn.pad_packed_sequence(x, batch_first=self.batch_first)
        else:
            lengths = None
        batch_size = x.size(0) if self.batch_first else x.size(1)

        # Sample one Bernoulli mask per sequence and reuse it at every time step.
        if self.batch_first:
            m = x.new_empty(batch_size, 1, x.size(2), requires_grad=False).bernoulli_(1 - self.dropout)
        else:
            m = x.new_empty(1, batch_size, x.size(2), requires_grad=False).bernoulli_(1 - self.dropout)
        x = x.masked_fill(m == 0, 0) / (1 - self.dropout)

        if is_packed:
            return rnn.pack_padded_sequence(x, lengths, batch_first=self.batch_first,
                                            enforce_sorted=False)
        return x
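
# Illustrative usage sketch (assumes a (seq_len, batch, features) input, since
# batch_first defaults to False here). The point is that the same feature mask
# is reused at every time step of a sequence:
#
#   drop = VariationalDropout(dropout=0.5, batch_first=False)
#   drop.train()
#   x = torch.ones(10, 4, 8)          # (T, B, F)
#   y = drop(x)                       # zeroed features match across all 10 steps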


class LSTM(nn.LSTM):
    def __init__(self, *args, input_drop,
                 weight_drop, output_drop,
                 batch_first=False, unit_forget_bias=True, **kwargs):
        super().__init__(*args, **kwargs, batch_first=batch_first)
        self.unit_forget_bias = unit_forget_bias
        self.weight_drop = weight_drop
        self.input_drop = VariationalDropout(input_drop,
                                             batch_first=batch_first)
        self.output_drop = VariationalDropout(output_drop,
                                              batch_first=batch_first)
        self._init_weights()

    def _init_weights(self):
        """
        Use orthogonal init for recurrent weights and Xavier-uniform init for
        input weights. Biases are zero, except that the forget-gate bias is set
        to 1 when unit_forget_bias is True.
        """
        for name, param in self.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param.data)
            elif "weight_ih" in name:
                nn.init.xavier_uniform_(param.data)
            elif "bias" in name and self.unit_forget_bias:
                nn.init.zeros_(param.data)
                # nn.LSTM orders gates as (input, forget, cell, output), so the
                # second hidden_size-wide slice is the forget gate.
                param.data[self.hidden_size:2 * self.hidden_size] = 1

    def _drop_weights(self):
        # DropConnect-style dropout on the hidden-to-hidden weight matrices,
        # resampled on every forward pass while training.
        for name, param in self.named_parameters():
            if "weight_hh" in name:
                getattr(self, name).data = \
                    torch.nn.functional.dropout(param.data, p=self.weight_drop,
                                                training=self.training).contiguous()

    def forward(self, x, hx=None):
        self._drop_weights()
        self.flatten_parameters()
        x = self.input_drop(x)
        seq, state = super().forward(x, hx=hx)
        return self.output_drop(seq), state
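
# Illustrative usage sketch (the positional args are simply forwarded to
# nn.LSTM). The call below builds a 2-layer LSTM with variational dropout on
# inputs/outputs and DropConnect on the recurrent weights:
#
#   lstm = LSTM(32, 64, 2, input_drop=0.2, weight_drop=0.3, output_drop=0.2,
#               batch_first=True)
#   seq, (h, c) = lstm(torch.randn(4, 10, 32))   # seq: (4, 10, 64)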


# https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/blob/master/models.py
class Attention(nn.Module):
    """
    Attention Network.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        """
        :param encoder_dim: feature size of encoded images
        :param decoder_dim: size of decoder's RNN
        :param attention_dim: size of the attention network
        """
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # linear layer to transform encoded image
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # linear layer to transform decoder's output
        self.full_att = nn.Linear(attention_dim, 1)  # linear layer to calculate values to be softmax-ed
        # self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)  # softmax layer to calculate weights

    def forward(self, encoder_out, decoder_hidden):
        """
        Forward propagation.

        :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
        :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim)
        :return: attention weighted encoding, weights
        """
        att1 = self.encoder_att(encoder_out)  # (batch_size, num_pixels, attention_dim)
        att2 = self.decoder_att(decoder_hidden)  # (batch_size, attention_dim)
        att = self.full_att(self.tanh(att1 + att2.unsqueeze(1))).squeeze(2)  # (batch_size, num_pixels)
        alpha = self.softmax(att)  # (batch_size, num_pixels)
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # (batch_size, encoder_dim)
        return attention_weighted_encoding, alpha
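
# Illustrative usage sketch (shapes follow the docstring above; "num_pixels" is
# the number of spatial positions in the encoder feature map, e.g. 14*14 = 196):
#
#   attn = Attention(encoder_dim=512, decoder_dim=256, attention_dim=128)
#   ctx, alpha = attn(torch.randn(4, 196, 512), torch.randn(4, 256))
#   # ctx: (4, 512), alpha: (4, 196), alpha sums to 1 over num_pixels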


def global_weight_init(m):
    """Xavier-uniform init for nn.Linear weights; zero the bias if present."""
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
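

if __name__ == "__main__":
    # Minimal smoke-test sketch: global_weight_init is intended to be used with
    # nn.Module.apply, which visits every submodule (the layer sizes below are
    # arbitrary, chosen only for illustration).
    model = Attention(encoder_dim=512, decoder_dim=256, attention_dim=128)
    model.apply(global_weight_init)          # re-initializes every nn.Linear
    ctx, alpha = model(torch.randn(2, 196, 512), torch.randn(2, 256))
    print(ctx.shape, alpha.shape)            # (2, 512), (2, 196)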