targetEncoder.py
import torch
import numpy as np
import torch.nn as nn
import math
import torchsummary
from config import config


class SkipConnection(nn.Module):
    """Residual wrapper: adds the wrapped module's output to its input."""

    def __init__(self, module):
        super(SkipConnection, self).__init__()
        self.module = module

    def forward(self, inputs):
        return inputs + self.module(inputs)


class MultiHeadAttention(nn.Module):
    def __init__(self, cfg, n_heads=8):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.input_dim = cfg.embedding_dim
        self.embedding_dim = cfg.embedding_dim
        self.value_dim = self.embedding_dim // self.n_heads
        self.key_dim = self.value_dim
        self.norm_factor = 1 / math.sqrt(self.key_dim)  # scaling for scaled dot-product attention

        self.w_query = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.key_dim))
        self.w_key = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.key_dim))
        self.w_value = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.value_dim))
        self.w_out = nn.Parameter(torch.Tensor(self.n_heads, self.value_dim, self.embedding_dim))

        self.init_parameters()

    def init_parameters(self):
        # Uniform initialization scaled by each parameter's last dimension.
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, q, h=None, mask=None):
        """
        :param q: queries (batch_size, n_query, input_dim)
        :param h: data (batch_size, graph_size, input_dim); defaults to q (self-attention)
        :param mask: mask (batch_size, n_query, graph_size) or viewable as that
            (i.e. can be 2-dim if n_query == 1). The mask should contain True/1
            where attention is not allowed (i.e. it is a negative adjacency).
        :return: output embeddings (batch_size, n_query, embedding_dim)
        """
        if h is None:
            h = q  # self-attention

        batch_size, target_size, input_dim = h.size()
        n_query = q.size(1)  # n_query == target_size for self-attention (e.g. TSP)
        assert q.size(0) == batch_size
        assert q.size(2) == input_dim
        assert input_dim == self.input_dim

        h_flat = h.contiguous().view(-1, input_dim)  # (batch_size * target_size, input_dim)
        q_flat = q.contiguous().view(-1, input_dim)  # (batch_size * n_query, input_dim)

        shape_v = (self.n_heads, batch_size, target_size, -1)
        shape_k = (self.n_heads, batch_size, target_size, -1)
        shape_q = (self.n_heads, batch_size, n_query, -1)

        Q = torch.matmul(q_flat, self.w_query).view(shape_q)  # (n_heads, batch_size, n_query, key_dim)
        K = torch.matmul(h_flat, self.w_key).view(shape_k)    # (n_heads, batch_size, target_size, key_dim)
        V = torch.matmul(h_flat, self.w_value).view(shape_v)  # (n_heads, batch_size, target_size, value_dim)

        # Compatibility scores: (n_heads, batch_size, n_query, target_size)
        U = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))

        if mask is not None:
            # Broadcast the mask over the head dimension and set masked scores to -inf
            # so they receive zero weight after the softmax (mask must be boolean).
            mask = mask.view(1, batch_size, n_query, target_size).expand_as(U)
            U[mask] = -np.inf

        attention = torch.softmax(U, dim=-1)  # (n_heads, batch_size, n_query, target_size)

        if mask is not None:
            # A fully masked row produces NaNs after the softmax; zero those entries out.
            attnc = attention.clone()
            attnc[mask] = 0
            attention = attnc

        heads = torch.matmul(attention, V)  # (n_heads, batch_size, n_query, value_dim)

        # Concatenate the heads and project back to embedding_dim.
        out = torch.mm(
            heads.permute(1, 2, 0, 3).reshape(-1, self.n_heads * self.value_dim),  # (batch_size * n_query, n_heads * value_dim)
            self.w_out.view(-1, self.embedding_dim)                                # (n_heads * value_dim, embedding_dim)
        ).view(batch_size, n_query, self.embedding_dim)

        return out  # (batch_size, n_query, embedding_dim)
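

# Minimal shape sketch for MultiHeadAttention (illustrative only; assumes a config
# object that exposes `embedding_dim`, here a hypothetical value of 128). It is not
# called anywhere in this module.
def _mha_shape_check():
    from types import SimpleNamespace

    cfg = SimpleNamespace(embedding_dim=128)
    mha = MultiHeadAttention(cfg, n_heads=8)
    q = torch.rand(4, 10, 128)        # (batch_size=4, n_query=10, input_dim=128)
    out = mha(q)                      # h defaults to q (self-attention), no mask
    assert out.shape == (4, 10, 128)  # (batch_size, n_query, embedding_dim)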


class Normalization(nn.Module):
    def __init__(self, cfg):
        super(Normalization, self).__init__()
        self.normalizer = nn.LayerNorm(cfg.embedding_dim, elementwise_affine=True)

    def forward(self, input):
        # Flatten to 2-D for LayerNorm, then restore the original shape.
        return self.normalizer(input.view(-1, input.size(-1))).view(*input.size())


class MultiHeadAttentionLayer(nn.Module):
    """Pre-norm Transformer block: (LayerNorm -> multi-head attention) followed by
    (LayerNorm -> feed-forward), each wrapped in a residual SkipConnection."""

    def __init__(self, cfg):
        super(MultiHeadAttentionLayer, self).__init__()
        self.layer = nn.Sequential(
            SkipConnection(nn.Sequential(Normalization(cfg=cfg),
                                         MultiHeadAttention(cfg))),
            SkipConnection(nn.Sequential(Normalization(cfg=cfg),
                                         nn.Linear(cfg.embedding_dim, 512),
                                         nn.ReLU(inplace=True),
                                         nn.Linear(512, cfg.embedding_dim)))
        )

    def forward(self, x):
        return self.layer(x)


class TargetEncoder(nn.Module):
    def __init__(self, cfg):
        super(TargetEncoder, self).__init__()
        self.init_depot_embedding = nn.Linear(2, cfg.embedding_dim)
        self.init_embedding = nn.Linear(2, cfg.embedding_dim)
        # The * unpacks the generator of attention layers into nn.Sequential.
        self.layers = nn.Sequential(*(
            MultiHeadAttentionLayer(cfg=cfg)
            for _ in range(3)
        ))

    def forward(self, depot_input, city_input, mask=None):
        assert mask is None, "TargetEncoder does not support a mask"
        # Embed depot and city coordinates separately, concatenate along the node
        # dimension, then refine with the stacked attention layers.
        h = torch.cat([self.init_depot_embedding(depot_input), self.init_embedding(city_input)], dim=1)
        h = self.layers(h)
        return h  # (batch_size, target_size, embedding_dim)
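

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions: the encoder only reads `embedding_dim`
    # from its config, so a hypothetical SimpleNamespace stands in for the
    # repository's `config` object imported above.
    from types import SimpleNamespace

    cfg = SimpleNamespace(embedding_dim=128)
    encoder = TargetEncoder(cfg)

    depot = torch.rand(2, 1, 2)    # (batch_size, n_depots=1, xy)
    cities = torch.rand(2, 20, 2)  # (batch_size, n_cities=20, xy)

    h = encoder(depot, cities)     # (batch_size, 1 + n_cities, embedding_dim)
    print(h.shape)                 # torch.Size([2, 21, 128])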