forked from huggingface/torchMoji
-
Notifications
You must be signed in to change notification settings - Fork 5
/
attlayer.py
68 lines (52 loc) · 2.56 KB
/
attlayer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# -*- coding: utf-8 -*-
""" Define the Attention Layer of the model.
"""
from __future__ import print_function, division
import torch
from torch.autograd import Variable
from torch.nn import Module
from torch.nn.parameter import Parameter
class Attention(Module):
    """
    Computes a weighted average of the different channels across timesteps.
    Uses 1 parameter pr. channel to compute the attention value for a single timestep.

    The attention weights are a softmax, over the valid (non-padded) timesteps
    of each sequence, of the dot product between each timestep and a learned
    attention vector.
    """

    def __init__(self, attention_size, return_attention=False):
        """ Initialize the attention layer

        # Arguments:
            attention_size: Size of the attention vector.
            return_attention: If true, output will include the weight for each input token
                              used for the prediction
        """
        super(Attention, self).__init__()
        self.return_attention = return_attention
        self.attention_size = attention_size
        # One learnable weight per channel, drawn from N(0, 0.05^2).
        self.attention_vector = Parameter(torch.FloatTensor(attention_size))
        self.attention_vector.data.normal_(std=0.05)  # Initialize attention vector

    def __repr__(self):
        s = '{name}({attention_size}, return attention={return_attention})'
        return s.format(name=self.__class__.__name__,
                        attention_size=self.attention_size,
                        return_attention=self.return_attention)

    def forward(self, inputs, input_lengths):
        """ Forward pass.

        # Arguments:
            inputs (torch.Tensor): Input sequences; expected to be laid out as
                (batch, timesteps, attention_size).
            input_lengths (torch.LongTensor): Number of valid (non-padded)
                timesteps of each sequence, one entry per batch element.

        # Return:
            Tuple with (representations and attentions if self.return_attention else None).
            A sequence of length 0 gets all-zero attention weights and an
            all-zero representation.
        """
        # One scalar score per timestep: (batch, timesteps).
        logits = inputs.matmul(self.attention_vector)
        max_len = logits.size(1)

        # Mark the padded positions of every sequence. The mask is built on the
        # same device as the activations so the layer also works off-CPU.
        # See e.g. https://discuss.pytorch.org/t/self-attention-on-words-and-masking/5671/5
        lengths = torch.as_tensor(input_lengths, device=logits.device)
        positions = torch.arange(max_len, device=logits.device).unsqueeze(0)
        padded = positions >= lengths.unsqueeze(1)

        # Padded scores become -inf so they contribute exp(-inf) == 0.
        neg_inf = float('-inf')
        masked_logits = logits.masked_fill(padded, neg_inf)

        # Shift by the max over the *valid* timesteps of each sequence before
        # exponentiating. A softmax is invariant to this shift, but a
        # batch-wide max could underflow every weight of a low-scoring
        # sequence to zero and produce 0/0 = NaN.
        row_max = masked_logits.max(dim=1, keepdim=True)[0].detach()
        # Fully padded rows have max == -inf; shift those by 0 to avoid
        # computing (-inf) - (-inf).
        row_max = row_max.masked_fill(row_max == neg_inf, 0.0)
        unnorm_ai = (masked_logits - row_max).exp()

        # Renormalize attention scores (weights) per sequence. The sum is only
        # zero for zero-length sequences; dividing by 1 there yields zero
        # weights instead of NaN.
        att_sums = unnorm_ai.sum(dim=1, keepdim=True)
        att_sums = att_sums.masked_fill(att_sums == 0, 1.0)
        attentions = unnorm_ai / att_sums

        # Apply the attention weights (broadcast over the channel dimension)
        # and sum over time to get one fixed-size vector per sequence.
        weighted = inputs * attentions.unsqueeze(-1)
        representations = weighted.sum(dim=1)

        return (representations, attentions if self.return_attention else None)