import torch
import torch.nn as nn
import torch.nn.functional as f
from collections import defaultdict


class Vocab:
    def __init__(self, vocabs):
        self._v2i = defaultdict(self._default_idx)
        v2i = {}
        for idx, vocab in enumerate(vocabs):
            v2i[vocab] = idx
        self._v2i.update(v2i)
        self._i2v = defaultdict(self._default_vocab)
        self._i2v.update({v: k for k, v in self._v2i.items()})
        # '<unk>' must be present in vocabs: its index is the fallback for unseen symbols
        self.default_idx = self._v2i['<unk>']
        self.len = len(self._v2i)
        assert len(self._v2i) == len(self._i2v)

    def _default_idx(self):
        return self.default_idx

    def _default_vocab(self):
        return '<unk>'

    def to_string(self, indices):
        # indices is a tensor of vocabulary indices
        s = ''
        for idx in indices:
            s += self._i2v[idx.item()]
        return s

    def to_string_list(self, indices):
        # indices is a tensor of vocabulary indices
        chars = []
        for idx in indices:
            chars.append(self._i2v[idx.item()])
        return chars

    def __len__(self):
        # report the original length: the defaultdicts insert unseen keys upon first access,
        # so len(self._v2i) can grow after lookups
        return self.len
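

# Illustrative usage sketch (not part of the original pipeline): it shows the
# intended behaviour of Vocab. The symbols other than the special tokens
# '<unk>', '<s>', and '>' are arbitrary placeholders.
def _vocab_example():
    vocab = Vocab(['<unk>', '<s>', '>', 'a', 'b'])
    assert len(vocab) == 5
    # known indices map back to their symbols
    assert vocab.to_string_list(torch.tensor([3, 4, 2])) == ['a', 'b', '>']
    # an out-of-vocabulary index falls back to '<unk>' via the defaultdict factory
    assert vocab.to_string(torch.tensor([99])) == '<unk>'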


class Embedding(nn.Module):
    """Feature extraction:
    looks up embeddings for a source character and its language,
    concatenates the two,
    and passes the result through a linear layer
    """
    def __init__(self, embedding_dim, langs, C2I):
        super(Embedding, self).__init__()
        self.langs = langs
        self.char_embeddings = nn.Embedding(len(C2I), embedding_dim)
        self.lang_embeddings = nn.Embedding(len(self.langs), embedding_dim)
        # map the concatenated character and language embeddings back to one embedding
        self.fc = nn.Linear(2 * embedding_dim, embedding_dim)

    def forward(self, char_indices, lang_indices):
        # both result in (L, E), where L is the length of the entire cognate set
        chars_embedded = self.char_embeddings(char_indices)
        lang_embedded = self.lang_embeddings(lang_indices)
        # concatenate the tensors to form one long embedding, then map down to the regular embedding size
        return self.fc(torch.cat((chars_embedded, lang_embedded), dim=-1))
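

# Illustrative sketch with assumed dimensions and a hypothetical language
# inventory (neither taken from the original repository): a length-L sequence
# of character indices plus parallel language indices maps to an (L, E) tensor.
def _embedding_example():
    C2I = Vocab(['<unk>', '<s>', '>', 'a', 'b'])
    langs = ['protolang', 'sep', 'daughter1']  # hypothetical language inventory
    emb = Embedding(embedding_dim=8, langs=langs, C2I=C2I)
    char_indices = torch.tensor([1, 3, 4, 2])  # <s> a b >
    lang_indices = torch.tensor([1, 2, 2, 1])  # sep, daughter1, daughter1, sep
    assert emb(char_indices, lang_indices).shape == (4, 8)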


class MLP(nn.Module):
    """
    Multi-layer perceptron to generate logits from the decoder state
    """
    def __init__(self, hidden_dim, feedforward_dim, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(hidden_dim, 2 * feedforward_dim)
        self.fc2 = nn.Linear(2 * feedforward_dim, feedforward_dim)
        self.fc3 = nn.Linear(feedforward_dim, output_size, bias=False)
        # no need to perform softmax because CrossEntropyLoss does the softmax for you

    def forward(self, decoder_state):
        h = f.relu(self.fc1(decoder_state))
        scores = self.fc3(f.relu(self.fc2(h)))
        return scores
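

# Illustrative sketch (assumed sizes): the MLP maps a decoder state of size H
# to one logit per output symbol.
def _mlp_example():
    mlp = MLP(hidden_dim=16, feedforward_dim=32, output_size=5)
    decoder_state = torch.zeros(1, 1, 16)
    assert mlp(decoder_state).shape == (1, 1, 5)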


class Attention(nn.Module):
    def __init__(self, hidden_dim, embedding_dim):
        super(Attention, self).__init__()
        self.W_query = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_key = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_c_s = nn.Linear(embedding_dim, hidden_dim, bias=False)

    def forward(self, query, keys, encoded_input):
        # query: decoder state. [1, 1, H]
        # keys: encoder states. [1, L, H]
        query = self.W_query(query)
        # dot product attention to calculate the similarity between the query and each key
        # scores: [1, L, 1]
        scores = torch.matmul(keys, query.transpose(1, 2))
        # softmax to get a probability distribution over the L encoder states
        weights = f.softmax(scores, dim=-2)
        # weights: [1, L, 1]
        # encoded_input: [1, L, E], projected to [1, L, H]
        # keys: [1, L, H]
        # result after summing over L: [1, H] - a weighted summary of the input
        weighted_states = weights * (self.W_c_s(encoded_input) + self.W_key(keys))
        weighted_states = weighted_states.sum(dim=-2)
        return weighted_states
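

# Illustrative sketch (assumed sizes): given a decoder query [1, 1, H], encoder
# states [1, L, H], and the embedded input [1, L, E], the module returns one
# attention-weighted summary of shape [1, H].
def _attention_example():
    H, E, L = 16, 8, 6
    attn = Attention(hidden_dim=H, embedding_dim=E)
    query = torch.zeros(1, 1, H)
    keys = torch.zeros(1, L, H)
    encoded_input = torch.zeros(1, L, E)
    assert attn(query, keys, encoded_input).shape == (1, H)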


class Model(nn.Module):
    """
    Encoder-decoder architecture
    """
    def __init__(self, C2I,
                 num_layers,
                 dropout,
                 feedforward_dim,
                 embedding_dim,
                 model_size,
                 model_type,
                 langs):
        super(Model, self).__init__()
        self.C2I = C2I
        # share the embedding across all languages, including the proto-language
        self.embeddings = Embedding(embedding_dim, langs, C2I)
        # have a separate embedding for the language
        # technically, most of the vocab is not used for the separator embedding,
        # but since the separator tokens and the character embeddings are disjoint, put them all in the same matrix
        self.langs = langs
        self.protolang = langs[0]
        self.L2I = {l: idx for idx, l in enumerate(langs)}
        self.dropout = nn.Dropout(dropout)
        if model_type == "gru":
            self.encoder_rnn = nn.GRU(input_size=embedding_dim,
                                      hidden_size=model_size,
                                      num_layers=num_layers,
                                      batch_first=True)
            self.decoder_rnn = nn.GRU(input_size=embedding_dim + model_size,
                                      hidden_size=model_size,
                                      num_layers=num_layers,
                                      batch_first=True)
        self.mlp = MLP(hidden_dim=model_size, feedforward_dim=feedforward_dim, output_size=len(C2I))
        self.attention = Attention(hidden_dim=model_size, embedding_dim=embedding_dim)

    def forward(self, source_tokens, source_langs, target_tokens, target_langs, device):
        # encoder
        # encoder_states: 1 x L x H, memory: 1 x 1 x H, where L is the length of the cognate set
        (encoder_states, memory), embedded_cognateset = self.encode(source_tokens, source_langs, device)
        # perform dropout on the output of the RNN
        encoder_states = self.dropout(encoder_states)
        # decoder
        # start of the protoform sequence
        start_char = (torch.tensor([self.C2I._v2i["<s>"]]).to(device), torch.tensor([self.L2I["sep"]]).to(device))
        start_encoded = self.embeddings(*start_char)
        # initialize the weighted states to the final encoder state
        attention_weighted_states = memory.squeeze(dim=0)
        # start_encoded: 1 x E, attention_weighted_states: 1 x H
        # concatenated into 1 x (E + H)
        decoder_input = torch.cat((start_encoded, attention_weighted_states), dim=1).unsqueeze(dim=0)
        # perform dropout on the input to the RNN
        decoder_input = self.dropout(decoder_input)
        decoder_state, _ = self.decoder_rnn(decoder_input)
        # perform dropout on the output of the RNN
        decoder_state = self.dropout(decoder_state)
        scores = []
        for lang, char in zip(target_langs, target_tokens):
            # lang will either be sep or the protolang
            # embedding layer for the teacher-forced target character
            true_char_embedded = self.embeddings(char, lang).unsqueeze(dim=0)
            # MLP to get logits over the possible output phonemes (softmax is applied by the loss)
            char_scores = self.mlp(decoder_state + attention_weighted_states)
            scores.append(char_scores.squeeze(dim=0))
            # dot product attention over the encoder states - results in (1, H)
            attention_weighted_states = self.attention(decoder_state, encoder_states, embedded_cognateset)
            # decoder_input: (1, 1, E + H)
            decoder_input = torch.cat((true_char_embedded, attention_weighted_states), dim=1).unsqueeze(dim=0)
            # perform dropout on the input to the RNN
            decoder_input = self.dropout(decoder_input)
            decoder_state, _ = self.decoder_rnn(decoder_input)
            # perform dropout on the output of the RNN
            decoder_state = self.dropout(decoder_state)
        # stack the T score tensors of shape (1, |Y|) into (T, |Y|)
        scores = torch.vstack(scores)
        return scores

    def encode(self, source_tokens, source_langs, device):
        # source_tokens / source_langs: vocabulary and language indices for the daughter forms
        embedded_cognateset = self.embeddings(source_tokens, source_langs).to(device)
        # batch size of 1
        embedded_cognateset = embedded_cognateset.unsqueeze(dim=0)
        # perform dropout on the input to the RNN
        embedded_cognateset = self.dropout(embedded_cognateset)
        return self.encoder_rnn(embedded_cognateset), embedded_cognateset

    def decode(self, encoder_states, memory, embedded_cognateset, max_length, device):
        # greedy decoding - generate the protoform by picking the most likely character at each time step
        start_char = (torch.tensor([self.C2I._v2i["<s>"]]).to(device), torch.tensor([self.L2I["sep"]]).to(device))
        start_encoded = self.embeddings(*start_char).to(device)
        # initialize the weighted states to the final encoder state
        attention_weighted_states = memory.squeeze(dim=0)
        # start_encoded: 1 x E, attention_weighted_states: 1 x H
        # concatenated into 1 x (E + H)
        decoder_input = torch.cat((start_encoded, attention_weighted_states), dim=1).unsqueeze(dim=0)
        decoder_state, _ = self.decoder_rnn(decoder_input)
        reconstruction = []
        for _ in range(max_length):
            # MLP to get logits over the possible output phonemes
            char_scores = self.mlp(decoder_state + attention_weighted_states)
            # char_scores: [1, 1, |Y|]
            predicted_char_idx = torch.argmax(char_scores.squeeze(dim=0)).item()
            # embed the predicted character, tagged with the proto-language
            predicted_char = (torch.tensor([predicted_char_idx]).to(device), torch.tensor([self.L2I[self.protolang]]).to(device))
            predicted_char_embedded = self.embeddings(*predicted_char)
            # dot product attention over the encoder states
            attention_weighted_states = self.attention(decoder_state, encoder_states, embedded_cognateset)
            # decoder_input: (1, 1, E + H)
            decoder_input = torch.cat((predicted_char_embedded, attention_weighted_states), dim=1).unsqueeze(dim=0)
            decoder_state, _ = self.decoder_rnn(decoder_input)
            reconstruction.append(predicted_char_idx)
            # end of sequence generated
            if predicted_char_idx == self.C2I._v2i[">"]:
                break
        return torch.tensor(reconstruction)
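

# Illustrative end-to-end sketch. Every name, size, and the language inventory
# below is an assumption for demonstration only; the real vocabulary and
# language list come from the data loading code, which is outside this file.
def _model_example():
    device = torch.device('cpu')
    C2I = Vocab(['<unk>', '<s>', '>', 'a', 'b'])
    # the proto-language is assumed to come first; "sep" labels separator tokens
    langs = ['protolang', 'sep', 'daughter1', 'daughter2']
    model = Model(C2I=C2I, num_layers=1, dropout=0.1, feedforward_dim=32,
                  embedding_dim=8, model_size=16, model_type='gru', langs=langs)
    # a toy cognate set: <s> a b > contributed by daughter1, framed by separators
    source_tokens = torch.tensor([1, 3, 4, 2])
    source_langs = torch.tensor([1, 2, 2, 1])
    # teacher-forced protoform target: <s> a >
    target_tokens = torch.tensor([1, 3, 2])
    target_langs = torch.tensor([1, 0, 0])
    # training-style forward pass: one row of logits per target character
    scores = model(source_tokens, source_langs, target_tokens, target_langs, device)
    assert scores.shape == (3, len(C2I))
    # greedy decoding reuses the encoder output
    model.eval()
    with torch.no_grad():
        (encoder_states, memory), embedded = model.encode(source_tokens, source_langs, device)
        reconstruction = model.decode(encoder_states, memory, embedded, max_length=10, device=device)
    print(C2I.to_string_list(reconstruction))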