Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

resnet50剪枝报错 #24

Open
Wq-dd opened this issue Aug 23, 2023 · 1 comment
Open

resnet50剪枝报错 #24

Wq-dd opened this issue Aug 23, 2023 · 1 comment
Labels
question Further information is requested

Comments

@Wq-dd
Copy link

Wq-dd commented Aug 23, 2023

你好,我在使用resnet18为主干网的retinanet时,自己使用稀疏训练后的模型剪枝会报错,我的做法是:

  1. 首先将训练好的模型计算bn的阈值得到每个bn层应该要剪枝的索引,并保存到一个dict里。
  2. 然后循环1中的dict使用torchprunner去剪枝,会遇到前面的某些层如果剪了过多通道,后面层再剪时会出现索引越界。
  3. 下面是我的部分代码。
        import torchpruner 
        # 创建ONNXGraph对象,绑定需要被剪枝的模型
        self.model.eval()
        graph = torchpruner.ONNXGraph(self.model.cpu())
        ##build ONNX静态图结构,需要指定输入的张量
        graph.build_graph(inputs=(torch.zeros(1, 3, 640, 640),))
        for i, (k, v) in enumerate(mask_dict_for_pruner.items()):
        # 获取conv1模块对应的module
            conv1_module = graph.modules[k]

            # 对前四个通道进行剪枝分析,指定对weight权重进行剪枝,剪枝前四个通道
            # weight权重out_channels对应的通道维度为0
            result = conv1_module.cut_analysis(attribute_name="weight", index=v, dim=0)

            # 剪枝执行模块执行剪枝操作,对模型完成剪枝过程.context变量提供了用于剪枝恢复的上下文
            self.model, context = torchpruner.set_cut(self.model, result)
        # 新的model即为剪枝后的模型
        print(self.model)```

请问是我的用法不对吗还是说这种先计算剪枝的索引再调用torchpruner的方法不对呢
@gdh1995
Copy link
Collaborator

gdh1995 commented Aug 27, 2023

每次剪枝后,model 对象变了,就都要重建 graph、重新执行 build_graph

@gdh1995 gdh1995 added the question Further information is requested label Aug 27, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
question Further information is requested
Projects
None yet
Development

No branches or pull requests

2 participants