diff --git a/onnxslim/core/optimizer.py b/onnxslim/core/optimizer.py index f9c8b87..e5c7f78 100644 --- a/onnxslim/core/optimizer.py +++ b/onnxslim/core/optimizer.py @@ -172,9 +172,15 @@ def dead_node_elimination(graph): logger.debug(f"removing Add op: {node.name}") elif node.op == "Expand": # tests/test_onnx_nets.py::TestTimmClass::test_timm[lambda_resnet26rpt_256] - if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant) and np.all(node.inputs[1].values == 1): - delete_node(node) - logger.debug(f"removing Expand op: {node.name}") + if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant): + constant_variable = node.inputs[1] + value = constant_variable.values + if value.ndim == 0 and value == 0: + delete_node(node) + logger.debug(f"removing Expand op: {node.name}") + elif np.all(value == 0) and (node.inputs[0].shape == value.shape): + delete_node(node) + logger.debug(f"removing Expand op: {node.name}") elif node.op == "Concat": if len(node.inputs) == 1: delete_node(node)