Skip to content

Commit

Permalink
fix: not use an out-of-date ONNXGraph
Browse files Browse the repository at this point in the history
This is for
#7 (comment)
  • Loading branch information
gdh1995 committed Dec 7, 2021
1 parent 20005c4 commit 4aacd11
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion examples/torchpruner/prune_by_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
graph.build_graph(inputs=(torch.zeros(1, 3, 224, 224),))

# 遍历所有的Module
for key in graph.modules:
for key in list(graph.modules):
module = graph.modules[key]
# 如果该module对应了BN层
if isinstance(module.nn_object, torch.nn.BatchNorm2d):
Expand All @@ -28,6 +28,10 @@
index = np.argsort(np.abs(weight))[: int(weight.shape[0] * 0.2)]
result = module.cut_analysis("weight", index=index, dim=0)
model, context = torchpruner.set_cut(model, result)
if context:
# graph 存放了各层参数和输出张量的 numpy.ndarray 版本,需要更新
graph = torchpruner.ONNXGraph(model) # 也可以不重新创建 graph
graph.build_graph(inputs=(torch.zeros(1, 3, 224, 224),))

# 新的model即为剪枝后的模型
print(model)

0 comments on commit 4aacd11

Please sign in to comment.