BF16 support in the ONNX export (#1654)
Export works; ONNX Runtime (ORT) inference does not yet.
fxmarty authored Jan 26, 2024
1 parent 843d3f4 commit 45c1c09
Showing 5 changed files with 54 additions and 8 deletions.
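For orientation, a minimal sketch of what the new option enables from Python (the model name and output path are simply the placeholders already used in the docstring example further down; the CLI counterpart added in this commit is the `--dtype bf16` flag on `optimum-cli export onnx`):

```python
# Hedged sketch: export a model to ONNX in bfloat16 via the new `dtype` argument.
# "gpt2" and "gpt2_onnx/" are placeholder values borrowed from the existing docstring.
from optimum.exporters.onnx import main_export

main_export("gpt2", output="gpt2_onnx/", dtype="bf16")
```

Per the commit description, the export itself succeeds, but the resulting bf16 model is not expected to run under ONNX Runtime yet.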
8 changes: 8 additions & 0 deletions optimum/commands/export/onnx.py
@@ -62,6 +62,13 @@ def parse_args_onnx(parser):
action="store_true",
help="Use half precision during the export. PyTorch-only, requires `--device cuda`.",
)
+optional_group.add_argument(
+"--dtype",
+type=str,
+default=None,
+choices=["fp32", "fp16", "bf16"],
+help="The floating point precision to use for the export. Supported options: fp32 (float32), fp16 (float16), bf16 (bfloat16).",
+)
optional_group.add_argument(
"--optimize",
type=str,
@@ -253,6 +260,7 @@ def run(self):
opset=self.args.opset,
device=self.args.device,
fp16=self.args.fp16,
+dtype=self.args.dtype,
optimize=self.args.optimize,
monolith=self.args.monolith,
no_post_process=self.args.no_post_process,
47 changes: 42 additions & 5 deletions optimum/exporters/onnx/__main__.py
@@ -167,6 +167,7 @@ def main_export(
task: str = "auto",
opset: Optional[int] = None,
device: str = "cpu",
+dtype: Optional[str] = None,
fp16: Optional[bool] = False,
optimize: Optional[str] = None,
monolith: bool = False,
@@ -216,6 +217,8 @@
The device to use to do the export. Defaults to "cpu".
fp16 (`Optional[bool]`, defaults to `"False"`):
Use half precision during the export. PyTorch-only, requires `device="cuda"`.
+dtype (`Optional[str]`, defaults to `None`):
+The floating point precision to use for the export. Supported options: `"fp32"` (float32), `"fp16"` (float16), `"bf16"` (bfloat16). Defaults to `"fp32"`.
optimize (`Optional[str]`, defaults to `None`):
Allows to run ONNX Runtime optimizations directly during the export. Some of these optimizations are specific to
ONNX Runtime, and the resulting ONNX will not be usable with other runtime as OpenVINO or TensorRT.
@@ -283,16 +286,31 @@
>>> main_export("gpt2", output="gpt2_onnx/")
```
"""

+if fp16:
+    if dtype is not None:
+        raise ValueError(
+            f'Both the arguments `fp16` ({fp16}) and `dtype` ({dtype}) were specified in the ONNX export, which is not supported. Please specify only `dtype`. Possible options: "fp32" (default), "fp16", "bf16".'
+        )
+
+    logger.warning(
+        'The argument `fp16` is deprecated in the ONNX export. Please use the argument `dtype="fp16"` instead, or `--dtype fp16` from the command-line.'
+    )
+
+    dtype = "fp16"
+elif dtype is None:
+    dtype = "fp32"  # Defaults to float32.

if optimize == "O4" and device != "cuda":
raise ValueError(
"Requested O4 optimization, but this optimization requires to do the export on GPU."
" Please pass the argument `--device cuda`."
)

if (framework == "tf" and fp16 is True) or not is_torch_available():
if (framework == "tf" and fp16) or not is_torch_available():
raise ValueError("The --fp16 option is supported only for PyTorch.")

-if fp16 and device == "cpu":
+if dtype == "fp16" and device == "cpu":
raise ValueError(
"FP16 export is supported only when exporting on GPU. Please pass the option `--device cuda`."
)
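The deprecation handling above keeps the legacy flag working; a hedged sketch of the call patterns it implies (behaviour read off the diff itself, not separate documentation — model name and output path are placeholders):

```python
from optimum.exporters.onnx import main_export

# Legacy flag: still accepted, but warns and is internally rewritten to dtype="fp16".
main_export("gpt2", output="gpt2_onnx/", fp16=True, device="cuda")

# Preferred form going forward.
main_export("gpt2", output="gpt2_onnx/", dtype="fp16", device="cuda")

# Passing both fp16 and dtype raises a ValueError, per the check above.
# main_export("gpt2", output="gpt2_onnx/", fp16=True, dtype="fp16")
```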
@@ -311,7 +329,13 @@
library_name = TasksManager.infer_library_from_model(
model_name_or_path, subfolder=subfolder, library_name=library_name
)
-torch_dtype = None if fp16 is False else torch.float16

+torch_dtype = None
+if framework == "pt":
+    if dtype == "fp16":
+        torch_dtype = torch.float16
+    elif dtype == "bf16":
+        torch_dtype = torch.bfloat16

if task.endswith("-with-past") and monolith:
task_non_past = task.replace("-with-past", "")
@@ -479,8 +503,16 @@ def onnx_export(
):
library_name = TasksManager._infer_library_from_model(model)
framework = "pt" if is_torch_available() and isinstance(model, torch.nn.Module) else "tf"

dtype = get_parameter_dtype(model) if framework == "pt" else model.dtype
float_dtype = "fp16" if "float16" in str(dtype) else "fp32"

if "bfloat16" in str(dtype):
float_dtype = "bf16"
elif "float16" in str(dtype):
float_dtype = "fp16"
else:
float_dtype = "fp32"

model_type = "stable-diffusion" if library_name == "diffusers" else model.config.model_type.replace("_", "-")
custom_architecture = library_name == "transformers" and model_type not in TasksManager._SUPPORTED_MODEL_TYPE
task = TasksManager.map_from_synonym(task)
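The dtype sniffing above keys off the string form of the model's parameter dtype; a rough, self-contained illustration of the same check against a model loaded directly in bfloat16 (`AutoModel` and `"gpt2"` are placeholders, and `next(model.parameters()).dtype` stands in for `get_parameter_dtype`):

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("gpt2", torch_dtype=torch.bfloat16)
dtype = next(model.parameters()).dtype  # stand-in for get_parameter_dtype(model)

if "bfloat16" in str(dtype):   # str(torch.bfloat16) == "torch.bfloat16"
    float_dtype = "bf16"
elif "float16" in str(dtype):
    float_dtype = "fp16"
else:
    float_dtype = "fp32"

print(float_dtype)  # -> "bf16"
```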
@@ -615,14 +647,19 @@ def onnx_export(

model.save_config(output)

if float_dtype == "bf16":
logger.warning(
f"Exporting the model {model.__class__.__name__} in bfloat16 float dtype. After the export, ONNX Runtime InferenceSession with CPU/CUDA execution provider likely does not implement all operators for the bfloat16 data type, and the loading is likely to fail."
)

_, onnx_outputs = export_models(
models_and_onnx_configs=models_and_onnx_configs,
opset=opset,
output_dir=output,
output_names=onnx_files_subpaths,
input_shapes=input_shapes,
device=device,
dtype="fp16" if float_dtype == "fp16" else None,
dtype=float_dtype,
no_dynamic_axes=no_dynamic_axes,
model_kwargs=model_kwargs,
)
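As the new warning states, ONNX Runtime's CPU/CUDA execution providers implement few operators for bfloat16, so loading a bf16 export is expected to fail for most models. A hedged sketch of how one might probe that (the file path is a placeholder for whatever the export produced):

```python
import onnxruntime as ort

try:
    # Expected to raise for most bf16 exports until ORT gains bfloat16 kernels.
    sess = ort.InferenceSession("gpt2_onnx/model.onnx", providers=["CPUExecutionProvider"])
except Exception as exc:
    print(f"bf16 model failed to load in ONNX Runtime: {exc}")
```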
4 changes: 3 additions & 1 deletion optimum/onnxruntime/modeling_decoder.py
@@ -151,7 +151,9 @@ def __init__(

self.use_fp16 = False
for inp in model.get_inputs():
-if (inp.name == "past_key_values" or inp.name in self.key_value_input_names) and inp.type == "tensor(float16)":
+if (
+    inp.name == "past_key_values" or inp.name in self.key_value_input_names
+) and inp.type == "tensor(float16)":
self.use_fp16 = True
break

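For reference, the reformatted check above looks at the input types declared by the ONNX Runtime session; a small sketch of the same inspection done by hand (the model path is a placeholder):

```python
import onnxruntime as ort

sess = ort.InferenceSession("decoder_model.onnx", providers=["CPUExecutionProvider"])
for inp in sess.get_inputs():
    # An fp16 model declares its cached key/value inputs as "tensor(float16)".
    print(inp.name, inp.type)
```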
1 change: 0 additions & 1 deletion optimum/onnxruntime/modeling_ort.py
@@ -14,7 +14,6 @@
"""ORTModelForXXX classes, allowing to run ONNX Models with ONNX Runtime using the same API as Transformers."""

import logging
-import math
import re
import shutil
from pathlib import Path
2 changes: 1 addition & 1 deletion tests/onnxruntime/test_modeling.py
@@ -3378,7 +3378,7 @@ def test_compare_to_io_binding(self, model_arch):
self.assertIsInstance(io_outputs.logits, torch.Tensor)

# compare tensor outputs
-self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), io_outputs.logits, atol=1e-1))
+self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), io_outputs.logits, atol=1e-1))

gc.collect()

