Skip to content

Commit

Permalink
Bump min torch version (#1641)
Browse files Browse the repository at this point in the history
* Update min torch version

* add sentencepiece

* remove install from source
  • Loading branch information
echarlaix committed Jan 19, 2024
1 parent 6b88276 commit d6fe4e3
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 56 deletions.
104 changes: 49 additions & 55 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import is_torch_less_than_1_11

if is_diffusers_available():
from diffusers import ModelMixin
Expand Down Expand Up @@ -556,62 +555,57 @@ def remap(value):

dummy_inputs = config.rename_ambiguous_inputs(dummy_inputs)

# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
# so we check the torch version for backwards compatibility
if is_torch_less_than_1_11:
raise RuntimeError("The ONNX export using the PyTorch framework is only supported for v1.11+")
else:
with config.patch_model_for_export(model, model_kwargs=model_kwargs):
check_dummy_inputs_are_allowed(model, dummy_inputs)

inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
output_names = list(config.outputs.keys())

# Export can work with named args but the dict containing named args has to be the last element of the args
# tuple.
onnx_export(
model,
(dummy_inputs,),
f=output.as_posix(),
input_names=input_names,
output_names=output_names,
dynamic_axes=dict(chain(inputs.items(), config.outputs.items())),
do_constant_folding=True,
opset_version=opset,
)
with config.patch_model_for_export(model, model_kwargs=model_kwargs):
check_dummy_inputs_are_allowed(model, dummy_inputs)

inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
output_names = list(config.outputs.keys())

# Export can work with named args but the dict containing named args has to be the last element of the args
# tuple.
onnx_export(
model,
(dummy_inputs,),
f=output.as_posix(),
input_names=input_names,
output_names=output_names,
dynamic_axes=dict(chain(inputs.items(), config.outputs.items())),
do_constant_folding=True,
opset_version=opset,
)

# check if external data was exported
# TODO: this is quite inefficient as we load in memory if models are <2GB without external data
onnx_model = onnx.load(str(output), load_external_data=False)
model_uses_external_data = check_model_uses_external_data(onnx_model)

if model_uses_external_data or FORCE_ONNX_EXTERNAL_DATA:
tensors_paths = _get_onnx_external_data_tensors(onnx_model)
logger.info("Saving external data to one file...")

# try free model memory
del model
del onnx_model
gc.collect()
if device.type == "cuda" and torch.cuda.is_available():
torch.cuda.empty_cache()

onnx_model = onnx.load(
str(output), load_external_data=True
) # this will probably be too memory heavy for large models
onnx.save(
onnx_model,
str(output),
save_as_external_data=True,
all_tensors_to_one_file=True,
location=output.name + "_data",
size_threshold=1024 if not FORCE_ONNX_EXTERNAL_DATA else 0,
)
# check if external data was exported
# TODO: this is quite inefficient as we load in memory if models are <2GB without external data
onnx_model = onnx.load(str(output), load_external_data=False)
model_uses_external_data = check_model_uses_external_data(onnx_model)

if model_uses_external_data or FORCE_ONNX_EXTERNAL_DATA:
tensors_paths = _get_onnx_external_data_tensors(onnx_model)
logger.info("Saving external data to one file...")

# try free model memory
del model
del onnx_model
gc.collect()
if device.type == "cuda" and torch.cuda.is_available():
torch.cuda.empty_cache()

onnx_model = onnx.load(
str(output), load_external_data=True
) # this will probably be too memory heavy for large models
onnx.save(
onnx_model,
str(output),
save_as_external_data=True,
all_tensors_to_one_file=True,
location=output.name + "_data",
size_threshold=1024 if not FORCE_ONNX_EXTERNAL_DATA else 0,
)

# delete previous external data
for tensor in tensors_paths:
os.remove(output.parent / tensor)
# delete previous external data
for tensor in tensors_paths:
os.remove(output.parent / tensor)

return input_names, output_names

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"coloredlogs",
"sympy",
"transformers[sentencepiece]>=4.26.0",
"torch>=1.9",
"torch>=1.11",
"packaging",
"numpy",
"huggingface_hub>=0.8.0",
Expand Down

0 comments on commit d6fe4e3

Please sign in to comment.