Skip to content

Commit

Permalink
Fix float32 conversion bug (#48)
Browse files Browse the repository at this point in the history
* fix float32 conversion bug

* [Release] 0.1.41

* Auto-format by https://ultralytics.com/actions

---------

Co-authored-by: UltralyticsAssistant <[email protected]>
  • Loading branch information
inisis and UltralyticsAssistant authored Nov 24, 2024
1 parent 13da157 commit 3049e6b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.40
0.1.41
11 changes: 11 additions & 0 deletions onnxslim/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.optimization import optimize_model
from onnxslim.core.utils import delete_node
from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx
from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant
from onnxslim.third_party.symbolic_shape_infer import SymbolicShapeInference
from onnxslim.utils import save
Expand Down Expand Up @@ -173,6 +174,16 @@ def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto:
inp_dtype = [input.dtype for input in node.inputs][0]
if inp_dtype in [np.float16, np.float32]:
delete_node(node)
else:
outp_dtype = [output.dtype for output in node.outputs][0]
if outp_dtype == np.float16:
node.attrs["to"] = dtype_to_onnx(np.float32)
node.outputs[0].dtype = np.float32
elif node.op == "ConstantOfShape":
if hasattr(node, "attrs") and "value" in node.attrs:
if node.attrs["value"].dtype == np.float16:
node.attrs["value"].values = node.attrs["value"].values.astype(np.float32)
node.outputs[0].dtype = np.float32

for tensor in graph.tensors().values():
if isinstance(tensor, gs.Variable) and tensor.dtype == np.float16:
Expand Down

0 comments on commit 3049e6b

Please sign in to comment.