diff --git a/VERSION b/VERSION index 7172442..37f868f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.1.40 +0.1.41 diff --git a/onnxslim/core/__init__.py b/onnxslim/core/__init__.py index 34dcfda..a24d0a6 100644 --- a/onnxslim/core/__init__.py +++ b/onnxslim/core/__init__.py @@ -9,6 +9,7 @@ import onnxslim.third_party.onnx_graphsurgeon as gs from onnxslim.core.optimization import optimize_model from onnxslim.core.utils import delete_node +from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant from onnxslim.third_party.symbolic_shape_infer import SymbolicShapeInference from onnxslim.utils import save @@ -173,6 +174,16 @@ def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto: inp_dtype = [input.dtype for input in node.inputs][0] if inp_dtype in [np.float16, np.float32]: delete_node(node) + else: + outp_dtype = [output.dtype for output in node.outputs][0] + if outp_dtype == np.float16: + node.attrs["to"] = dtype_to_onnx(np.float32) + node.outputs[0].dtype = np.float32 + elif node.op == "ConstantOfShape": + if hasattr(node, "attrs") and "value" in node.attrs: + if node.attrs["value"].dtype == np.float16: + node.attrs["value"].values = node.attrs["value"].values.astype(np.float32) + node.outputs[0].dtype = np.float32 for tensor in graph.tensors().values(): if isinstance(tensor, gs.Variable) and tensor.dtype == np.float16: