-
Notifications
You must be signed in to change notification settings - Fork 143
/
train.py
112 lines (72 loc) · 2.71 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import time
from util.time import *
from util.env import *
from sklearn.metrics import mean_squared_error
from test import *
import torch.nn.functional as F
import numpy as np
from evaluate import get_best_performance_data, get_val_performance_data, get_full_err_scores
from sklearn.metrics import precision_score, recall_score, roc_auc_score, f1_score
from torch.utils.data import DataLoader, random_split, Subset
from scipy.stats import iqr
def loss_func(y_pred, y_true):
    """Return the mean squared error between predictions and targets.

    Args:
        y_pred: Predicted tensor (the model output).
        y_true: Target tensor, forwarded unchanged to ``F.mse_loss``.

    Returns:
        Scalar tensor: the squared element-wise difference averaged over all
        elements (``reduction='mean'``).
    """
    return F.mse_loss(input=y_pred, target=y_true, reduction='mean')
def train(model=None, save_path='', config=None, train_dataloader=None, val_dataloader=None, feature_map=None, test_dataloader=None, test_dataset=None, dataset_name='swat', train_dataset=None):
    """Train ``model`` with Adam on an MSE objective and checkpoint the best weights.

    Each epoch makes one pass over ``train_dataloader``. After every epoch:

    * if ``val_dataloader`` is given, the model is evaluated with the project's
      ``test`` helper; the ``state_dict`` is saved whenever the validation loss
      improves, and training stops early after ``early_stop_win`` consecutive
      epochs without improvement;
    * otherwise the ``state_dict`` is saved whenever the summed training loss
      of the epoch is the lowest seen so far.

    Args:
        model: Module to train; called as ``model(x, edge_index)``.
        save_path: Path handed to ``torch.save`` for the best ``state_dict``.
        config: Dict of hyper-parameters. Required keys: ``'decay'`` (Adam
            weight decay) and ``'epoch'`` (maximum number of epochs).
            Optional keys: ``'lr'`` (default 0.001) and ``'early_stop_win'``
            (default 15).
        train_dataloader: Iterable of ``(x, labels, attack_labels, edge_index)``
            batches; ``attack_labels`` is not used during training.
        val_dataloader: Optional validation loader, passed to ``test``, whose
            result is unpacked as ``(loss, results)``.
        feature_map, test_dataloader, test_dataset, dataset_name, train_dataset:
            Not used here; kept so existing call sites keep working.

    Returns:
        List of per-batch training losses (Python floats) across all epochs.

    Raises:
        ValueError: If ``model`` is None.
    """
    if model is None:
        raise ValueError('train() requires a model')
    if config is None:
        config = {}

    device = get_device()
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config.get('lr', 0.001),
        weight_decay=config['decay'],
    )
    epoch = config['epoch']
    early_stop_win = config.get('early_stop_win', 15)

    train_loss_list = []
    # inf (not a finite sentinel) so the first epoch always checkpoints,
    # however large its loss is.
    min_loss = float('inf')
    stop_improve_count = 0
    dataloader = train_dataloader

    for i_epoch in range(epoch):
        acu_loss = 0
        n_batches = 0
        # Re-enter training mode every epoch: the validation helper may have
        # switched the model to eval mode at the end of the previous epoch.
        model.train()

        for x, labels, _attack_labels, edge_index in dataloader:
            x, labels, edge_index = [item.float().to(device) for item in [x, labels, edge_index]]

            optimizer.zero_grad()
            out = model(x, edge_index).float().to(device)
            loss = loss_func(out, labels)
            loss.backward()
            optimizer.step()

            # .item() forces a device->host sync; do it once per batch.
            batch_loss = loss.item()
            train_loss_list.append(batch_loss)
            acu_loss += batch_loss
            n_batches += 1

        # "Loss" is the mean batch loss of the epoch; max() guards against an
        # empty loader, which previously raised ZeroDivisionError.
        print('epoch ({} / {}) (Loss:{:.8f}, ACU_loss:{:.8f})'.format(
            i_epoch, epoch,
            acu_loss / max(n_batches, 1), acu_loss), flush=True
        )

        if val_dataloader is not None:
            # Model selection and early stopping on the validation loss.
            val_loss, _ = test(model, val_dataloader)
            if val_loss < min_loss:
                torch.save(model.state_dict(), save_path)
                min_loss = val_loss
                stop_improve_count = 0
            else:
                stop_improve_count += 1

            if stop_improve_count >= early_stop_win:
                break
        elif acu_loss < min_loss:
            # No validation data: select on the summed training loss instead.
            torch.save(model.state_dict(), save_path)
            min_loss = acu_loss

    return train_loss_list