# test_customized_layer.py -- forked from VainF/Torch-Pruning
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
from typing import Sequence
############
# Customize your layer
#
class CustomizedLayer(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.in_dim = in_dim
        self.scale = nn.Parameter(torch.Tensor(self.in_dim))
        self.bias = nn.Parameter(torch.Tensor(self.in_dim))

    def forward(self, x):
        # L2-normalize along the feature dimension, then apply a learnable
        # per-channel affine transform.
        norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()
        x = torch.div(x, norm)
        return x * self.scale + self.bias

    def __repr__(self):
        return "CustomizedLayer(in_dim=%d)" % (self.in_dim)
class FullyConnectedNet(nn.Module):
    """https://github.com/VainF/Torch-Pruning/issues/21"""
    def __init__(self, input_size, num_classes, HIDDEN_UNITS):
        super().__init__()
        self.fc1 = nn.Linear(input_size, HIDDEN_UNITS)
        self.customized_layer = CustomizedLayer(HIDDEN_UNITS)
        self.fc2 = nn.Linear(HIDDEN_UNITS, num_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.customized_layer(x)
        y_hat = self.fc2(x)
        return y_hat
############################
# Implement your pruning function for the customized layer
#
class MyPruningFn(tp.functional.structured.BasePruner):
    def prune(self, layer: CustomizedLayer, idxs: Sequence[int]) -> nn.Module:
        # Keep the channels that were NOT selected; sort so the surviving
        # parameters stay in their original order (a bare set difference
        # does not guarantee ordering).
        keep_idxs = sorted(set(range(layer.in_dim)) - set(idxs))
        layer.in_dim = layer.in_dim - len(idxs)
        layer.scale = torch.nn.Parameter(layer.scale.data.clone()[keep_idxs])
        layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs])
        return layer

    @staticmethod
    def calc_nparams_to_prune(layer: CustomizedLayer, idxs: Sequence[int]) -> int:
        # Every pruned channel removes one `scale` entry and one `bias` entry.
        nparams_to_prune = len(idxs) * 2
        return nparams_to_prune
my_pruning_fn = MyPruningFn()
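
# Quick standalone check (an added sketch, not in the original script):
# pruning two of eight channels should leave in_dim=6 and 6-element
# scale/bias parameters.
_demo_layer = my_pruning_fn.prune(CustomizedLayer(8), idxs=[0, 3])
assert _demo_layer.in_dim == 6
assert _demo_layer.scale.shape[0] == 6 and _demo_layer.bias.shape[0] == 6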
model = FullyConnectedNet(128, 10, 256)

# Select channels to prune according to their L1 norm
strategy = tp.strategy.L1Strategy()  # or tp.strategy.RandomStrategy()
DG = tp.DependencyGraph()

# Register your customized layer so the dependency graph knows how to prune it
DG.register_customized_layer(
    CustomizedLayer,
    in_ch_pruning_fn=my_pruning_fn,   # function that prunes the input channels/dimensions
    out_ch_pruning_fn=my_pruning_fn,  # function that prunes the output channels/dimensions
    get_in_ch_fn=lambda l: l.in_dim,  # number of input channels; return None if the layer does not change the tensor shape
    get_out_ch_fn=lambda l: l.in_dim) # number of output channels; return None if the layer does not change the tensor shape
# Build dependency graph
DG.build_dependency(model, example_inputs=torch.randn(1,128))
# Get a pruning plan from the dependency graph; `idxs` holds the indices of the channels to prune.
pruning_plan = DG.get_pruning_plan(model.fc1, tp.prune_linear_out_channel, idxs=strategy(model.fc1.weight, amount=0.4))
print(pruning_plan)
# Execute the plan (prune the model in place)
pruning_plan.exec()
print(model)
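
# Sanity check (an added sketch, not in the original script): a forward pass
# on the pruned model should still run end to end, since the dependency graph
# pruned fc1, the customized layer, and fc2 consistently.
with torch.no_grad():
    out = model(torch.randn(2, 128))
print("output shape:", out.shape)  # expected: torch.Size([2, 10])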