# BaseNet.py
from torch import Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
import copy
import numpy as np
from typing import Any, Callable, Dict, List, Optional, Tuple
from maraboupy import Marabou
from maraboupy.MarabouNetwork import MarabouNetwork
import tempfile
THRESHOLD = 10**-10
def get_activation(name: str, tensor_logger: Dict[str, Any],
                   detach: bool = True, is_lastlayer: bool = False) -> Callable[..., None]:
    """Build a forward hook that records the activation pattern of a layer under `name` in `tensor_logger`."""
    if is_lastlayer:
        def hook(model: torch.nn.Module, input: Tensor, output: Tensor) -> None:
            raw = torch.flatten(output, start_dim=1, end_dim=-1).cpu().detach().numpy()
            # use argmax instead of broadcasting just in case comparing floating point is finicky
            mask = np.zeros(raw.shape, dtype=bool)
            mask[np.arange(raw.shape[0]), raw.argmax(axis=1)] = 1
            tensor_logger[name] = np.concatenate((tensor_logger[name], mask),
                                                 axis=0) if name in tensor_logger else mask
        return hook
    if detach:
        def hook(model: torch.nn.Module, input: Tensor, output: Tensor) -> None:
            raw = torch.flatten(
                output, start_dim=1, end_dim=-1).cpu().detach().numpy()
            raw = raw > 0
            logging.debug("{}, {}".format(name, raw.shape))
            tensor_logger[name] = np.concatenate((tensor_logger[name], raw),
                                                 axis=0) if name in tensor_logger else raw
            logging.debug(tensor_logger[name].shape)
        return hook
    else:
        # keep the gradient, so cannot convert to bit here
        def hook(model: torch.nn.Module, input: Tensor, output: Tensor):
            raw = torch.sigmoid(torch.flatten(
                output, start_dim=1, end_dim=-1))
            logging.debug("{}, {}".format(name, raw.shape))
            tensor_logger[name] = torch.cat([tensor_logger[name], raw],
                                            dim=0) if name in tensor_logger else raw
            logging.debug(tensor_logger[name].shape)
        return hook
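

# A minimal sketch (not part of the original module) of how get_activation is
# meant to be used: register it as a forward hook on each layer whose on/off
# pattern should be recorded. The two-layer network and the layer names below
# are hypothetical; a real subclass would do this inside register_log().
def _example_hook_registration() -> Dict[str, Any]:
    tensor_logger: Dict[str, Any] = {}
    net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
    # record the ReLU on/off pattern of the hidden layer and the argmax mask of the output layer
    handles = [
        net[1].register_forward_hook(get_activation("relu1", tensor_logger)),
        net[2].register_forward_hook(get_activation("out", tensor_logger, is_lastlayer=True)),
    ]
    net(torch.rand(5, 4))      # one forward pass fills tensor_logger
    for h in handles:          # hooks should be removed when no longer needed
        h.remove()
    return tensor_logger       # {'relu1': (5, 8) bool array, 'out': (5, 3) bool array}

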
class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()
        self.tensor_log: Dict[str, Any] = {}
        self.gradient_log = {}
        self.hooks: List[Callable[..., None]] = []
        self.bw_hooks = []
        self.marabou_net: Optional[MarabouNetwork] = None

    def build_marabou_net(self, dummy_input: torch.Tensor) -> MarabouNetwork:
        """
        Convert the network to a MarabouNetwork by exporting it to ONNX and
        reading it back with Marabou.
        """
        tempf = tempfile.NamedTemporaryFile()
        torch.onnx.export(self, dummy_input, tempf.name, verbose=False)
        self.marabou_net = Marabou.read_onnx(tempf.name)
        assert self.check_network_consistancy(), "Marabou network is not consistent with the target network!!!"
        return self.marabou_net

    def check_network_consistancy(self) -> bool:
        """
        Check whether the built marabou_net is actually equivalent to the original net.
        Strategy: generate a random input and run it through both networks; the
        outputs should agree up to THRESHOLD.
        """
        if self.marabou_net is None:
            return False
        input_shape: Tuple[int, ...] = self.marabou_net.inputVars[0].shape
        dummy_inputs: List[torch.Tensor] = [torch.rand(input_shape)]
        marabou_output: List[np.ndarray] = self.marabou_net.evaluateWithMarabou(inputValues=dummy_inputs)
        internal_output: torch.Tensor = self.forward(torch.stack(dummy_inputs))
        marabou_output_flat: torch.Tensor = torch.Tensor(marabou_output[0]).squeeze().flatten()
        internal_output_flat: torch.Tensor = internal_output[0].squeeze().flatten()
        for idx in range(len(marabou_output_flat)):
            if abs(marabou_output_flat[idx] - internal_output_flat[idx]) > THRESHOLD:
                logging.info("Built marabou network is NOT consistent.\n Test outputs: {} != {}".format(
                    marabou_output_flat, internal_output_flat))
                return False
        logging.info("Built marabou network is consistent. Test outputs: {} vs {}".format(
            marabou_output_flat, internal_output_flat))
        return True

    def reset_hooks(self):
        self.tensor_log = {}
        for h in self.hooks:
            h.remove()

    def reset_bw_hooks(self):
        self.input_labels = None
        self.gradient_log = {}
        for h in self.bw_hooks:
            h.remove()

    def register_log(self, detach: bool) -> None:
        raise NotImplementedError

    def register_gradient(self, detach: bool) -> None:
        raise NotImplementedError

    def model_savename(self) -> str:
        raise NotImplementedError

    def get_pattern(self, input: Tensor,
                    layers: List[str],
                    device: torch.device,
                    detach: bool = True,
                    flatten: bool = True) -> Dict[str, np.ndarray]:
        """Run `input` through the network and return the recorded activation pattern."""
        self.eval()
        self.register_log(detach)
        self.forward(input.to(device))
        tensor_log = copy.deepcopy(self.tensor_log)
        if flatten:
            return {'all': np.concatenate([tensor_log[l] for l in layers], axis=1)}
        return tensor_log
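

# A minimal sketch (not part of the original module) of a concrete subclass and
# of how get_pattern and build_marabou_net are meant to be called. The
# architecture, the layer name "fc1", and the save name are hypothetical.
class _ExampleNet(BaseNet):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 3)

    def forward(self, x: Tensor) -> Tensor:
        return self.fc2(F.relu(self.fc1(x)))

    def register_log(self, detach: bool) -> None:
        # hooking fc1 records raw > 0 of its output, i.e. the on/off pattern of the ReLU that follows it
        self.reset_hooks()
        self.hooks.append(
            self.fc1.register_forward_hook(get_activation("fc1", self.tensor_log, detach)))

    def model_savename(self) -> str:
        return "example_net"


if __name__ == "__main__":
    net = _ExampleNet()
    # activation pattern of a small random batch, flattened across the listed layers
    pattern = net.get_pattern(torch.rand(5, 4), layers=["fc1"],
                              device=torch.device("cpu"))
    print(pattern["all"].shape)  # (5, 8)
    # export to Marabou; build_marabou_net asserts that the two networks agree
    net.build_marabou_net(torch.rand(1, 4))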