-
Notifications
You must be signed in to change notification settings - Fork 73
/
model.py
73 lines (62 loc) · 3.07 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Callable, List
import math
class Amodel(nn.Module):
    """Two-branch model: a bidirectional GRU over a variable-length series,
    optionally concatenated with an MLP over static features, followed by a
    sigmoid output head.

    Args:
        series_dim: feature dimension of each timestep in ``batch_series``.
        feature_dim: dimension of the static ``batch_feature`` vector.
        target_num: number of output targets (sigmoid-activated).
        hidden_num: number of layers in the static-feature MLP
            (``hidden_num - 1`` hidden blocks are created).
        hidden_dim: width of all hidden layers.
        drop_rate: dropout probability inside the static-feature MLP.
        use_series_oof: if True, the static-feature branch is used and its
            output is concatenated with the GRU pooling before the head.
    """

    def __init__(self, series_dim, feature_dim, target_num, hidden_num,
                 hidden_dim, drop_rate=0.5, use_series_oof=False):
        super(Amodel, self).__init__()
        self.use_series_oof = use_series_oof
        # Per-timestep projection of the raw series into the GRU input space.
        self.input_series_block = nn.Sequential(
            nn.Linear(series_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
        )
        # Entry block of the static-feature branch.
        self.input_feature_block = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(),
        )
        # Bidirectional, so the GRU output dimension is 2 * hidden_dim.
        self.gru_series = nn.GRU(hidden_dim, hidden_dim,
                                 batch_first=True, bidirectional=True)
        hidden_layers = []
        for _ in range(hidden_num - 1):
            hidden_layers.extend([
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(drop_rate),
                nn.LeakyReLU(),
            ])
        self.hidden_feature_block = nn.Sequential(*hidden_layers)
        # Head input is 2*hidden_dim (GRU) plus hidden_dim (features) when
        # the static-feature branch is enabled.
        self.output_block = nn.Sequential(
            nn.Linear(3 * hidden_dim if use_series_oof else 2 * hidden_dim,
                      1 * hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(1 * hidden_dim, 1 * hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(1 * hidden_dim, target_num),
            nn.Sigmoid(),
        )

    def batch_gru(self, series, mask):
        """Run the bidirectional GRU over variable-length sequences and
        return, for each sample, the GRU output at its last valid timestep.

        Args:
            series: padded input of shape ``(batch, max_len, hidden_dim)``.
            mask: ``(batch, max_len)`` validity mask; its row sums give the
                true sequence lengths (each length must be >= 1).

        Returns:
            Tensor of shape ``(batch, 2 * hidden_dim)`` — the output at the
            last valid timestep of each sequence, in original batch order.
        """
        # pack_padded_sequence requires a 1-D CPU int64 lengths tensor;
        # a float mask would otherwise make this crash.
        node_num = mask.sum(dim=-1).detach().cpu().long()
        pack = nn.utils.rnn.pack_padded_sequence(
            series, node_num, batch_first=True, enforce_sorted=False)
        message, hidden = self.gru_series(pack)
        # Packed data is timestep-major: all active sequences at t=0, then
        # t=1, ...  offsets[t] is the index in message.data where timestep t
        # begins (batch_sizes is always a CPU tensor).
        batch_sizes = message.batch_sizes
        offsets = torch.cat([batch_sizes.new_zeros(1),
                             batch_sizes.cumsum(0)[:-1]])
        # Sample i (original order) occupies sorted slot unsorted_indices[i]
        # within each timestep; its last valid timestep is node_num[i] - 1.
        last_step_offset = offsets[node_num - 1].to(
            message.unsorted_indices.device)
        idx = last_step_offset + message.unsorted_indices
        return message.data[idx]

    def forward(self, data):
        """Compute sigmoid outputs from ``data`` — a dict with keys
        ``batch_series``, ``batch_mask`` and, when ``use_series_oof`` is set,
        ``batch_feature``.
        """
        x1 = self.input_series_block(data['batch_series'])
        x1 = self.batch_gru(x1, data['batch_mask'])
        if self.use_series_oof:
            x2 = self.input_feature_block(data['batch_feature'])
            x2 = self.hidden_feature_block(x2)
            x = torch.cat([x1, x2], dim=1)
            y = self.output_block(x)
        else:
            y = self.output_block(x1)
        return y