RNN.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import fasttext_loader


class BiGRU(nn.Module):
    """
    BiGRU: bi-directional GRU built from a forward and a backward GRU cell.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.fw_gru = nn.GRUCell(input_dim, hidden_dim)  # Forward direction
        self.bw_gru = nn.GRUCell(input_dim, hidden_dim)  # Backward direction

    def forward(self, x):
        time = x.size(1)  # x = [batch_size x time x input_dim]
        # Forward pass over the sequence
        hx_1 = x.new_zeros(x.size(0), self.hidden_dim)
        for i in range(time):
            hx_1 = self.fw_gru(x[:, i], hx_1)
        # Backward pass over the sequence
        hx_2 = x.new_zeros(x.size(0), self.hidden_dim)
        for i in reversed(range(time)):
            hx_2 = self.bw_gru(x[:, i], hx_2)
        # Concatenate the final forward and backward representations
        h = torch.cat([hx_1, hx_2], dim=1)
        return h
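

# Illustrative shape check for BiGRU (a sketch, not part of the original
# module): with arbitrary example sizes batch_size=4, time=7, input_dim=10 and
# hidden_dim=16, the output should be [4 x 32], i.e. the concatenation of the
# final forward and backward hidden states.
def _bigru_shape_example():
    encoder = BiGRU(input_dim=10, hidden_dim=16)
    x = torch.randn(4, 7, 10)  # [batch_size x time x input_dim]
    h = encoder(x)
    assert h.shape == (4, 2 * 16)
    return h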


class RNN(nn.Module):
    """
    RNN: single-layer, bi-directional GRU encoder with a two-layer classifier.
    """

    def __init__(
            self, vocab_size, emb_dim, hidden_dim, dropout_prob,
            padding_idx, num_classes, id2tok):
        super().__init__()
        # Embedding layer
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=emb_dim,
            padding_idx=padding_idx,
        )
        # Load pre-trained fastText embedding weights and freeze them
        weights = fasttext_loader.create_weights(id2tok)
        self.embedding.weight.data.copy_(torch.from_numpy(weights))
        self.embedding.weight.requires_grad = False  # Freeze embeddings

        # Bi-directional GRU layer
        self.bigru = BiGRU(input_dim=emb_dim, hidden_dim=hidden_dim)

        # Dropout layer
        self.dropout = nn.Dropout(p=dropout_prob)

        # Fully connected layers. fc1 takes 2 * 2 * hidden_dim features: each
        # BiGRU encoding is 2 * hidden_dim (forward + backward), and the
        # premise and hypothesis encodings are concatenated (hx + hy).
        self.fc1 = nn.Linear(2 * 2 * hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, p, h):
        """
        Forward pass:
        1. Embed and encode the premises
        2. Embed and encode the hypotheses
        3. Concatenate the encoded sentences and classify

        @param p: premises,   [batch_size x time] token ids
        @param h: hypotheses, [batch_size x time] token ids
        """
        # Embed and encode premises
        x = self.embedding(p)
        hx = self.bigru(x)
        hx = self.dropout(hx)  # Dropout regularization

        # Embed and encode hypotheses
        y = self.embedding(h)
        hy = self.bigru(y)
        hy = self.dropout(hy)  # Dropout regularization

        # Concatenate sentence representations
        h = torch.cat((hx, hy), dim=1)

        # Feed the concatenation through the fully-connected layers
        h = F.relu(self.fc1(h))
        h = self.dropout(h)  # Dropout regularization
        h = self.fc2(h)
        return h
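

# Minimal usage sketch (not part of the original file). It assumes that
# fasttext_loader.create_weights(id2tok) returns a numpy array of shape
# [vocab_size x emb_dim]; the hyperparameters and the id2tok mapping below are
# made-up example values, not settings taken from this repository.
if __name__ == "__main__":
    vocab_size, emb_dim, hidden_dim, num_classes = 100, 300, 64, 3
    id2tok = {i: "tok{}".format(i) for i in range(vocab_size)}  # hypothetical vocabulary
    model = RNN(
        vocab_size=vocab_size,
        emb_dim=emb_dim,
        hidden_dim=hidden_dim,
        dropout_prob=0.1,
        padding_idx=0,
        num_classes=num_classes,
        id2tok=id2tok,
    )
    p = torch.randint(0, vocab_size, (8, 12))  # premises:   [batch_size x time]
    h = torch.randint(0, vocab_size, (8, 15))  # hypotheses: [batch_size x time]
    logits = model(p, h)
    print(logits.shape)  # expected: torch.Size([8, 3])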