ORT training support stage3 #1439

Closed
wants to merge 7 commits into from

Conversation

@pengwa commented Oct 8, 2023

The new version of ORT training adds ZeRO stage 3 support.

This PR enables that support when stage 3 is used.
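For illustration only (not code from this PR), here is a minimal sketch of how the stage 3 path could be exercised through Optimum's ORTTrainer with a DeepSpeed ZeRO stage-3 config. The config contents, file path, model name, and toy dataset are assumptions made for the example:

```python
# Minimal sketch (illustrative, not part of this PR): fine-tuning with
# Optimum's ORTTrainer while DeepSpeed ZeRO stage 3 is enabled.
# The DeepSpeed config, model name, and toy dataset are assumptions.
import json

from datasets import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments

# Hypothetical DeepSpeed config that turns on ZeRO stage 3.
ds_config = {
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {"stage": 3},
}
with open("ds_zero3.json", "w") as f:
    json.dump(ds_config, f)

model_name = "bert-base-uncased"  # illustrative model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Tiny toy dataset so the sketch is self-contained.
train_dataset = Dataset.from_dict(
    {"text": ["a good example", "a bad example"], "label": [1, 0]}
).map(lambda ex: tokenizer(ex["text"], truncation=True, padding="max_length", max_length=32))

args = ORTTrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    num_train_epochs=1,
    deepspeed="ds_zero3.json",  # stage 3 is picked up from this config
)

# Run under a distributed launcher (e.g. `deepspeed train.py` or `torchrun train.py`)
# so that ZeRO-3 parameter partitioning is actually active.
trainer = ORTTrainer(model=model, args=args, train_dataset=train_dataset, tokenizer=tokenizer)
trainer.train()
```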

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@pengwa (Author) commented Oct 10, 2023

@JingyaHuang could you please help take a look? Thanks a lot!

@JingyaHuang (Contributor)

Hi @pengwa, thanks a lot for the PR, and my apologies for the late reply. We just did some refactoring (#1335) of the ORTTrainer class, and we would like to finish that before merging any other new features.

Could you rebase the branch? Thx!

@pengwa (Author) commented Oct 21, 2023

> Hi @pengwa, thanks a lot for the PR, and my apologies for the late reply. We just did some refactoring (#1335) of the ORTTrainer class, and we would like to finish that before merging any other new features.
>
> Could you rebase the branch? Thx!

Thanks Jingya, updated.

@JingyaHuang (Contributor) left a comment

Thanks for updating the PR; please also fix the style with make style.

And a question about the ZeRO-3 support: is it experimental and only available after the 1.17 release, or is it possible to test it already? And any idea when ORT 1.17 will be released? Thanks!

@pengwa (Author) commented Oct 25, 2023

> Thanks for updating the PR; please also fix the style with make style.
>
> And a question about the ZeRO-3 support: is it experimental and only available after the 1.17 release, or is it possible to test it already? And any idea when ORT 1.17 will be released? Thanks!

Yeah, I will fix that later.

As for 1.17, it will ship in the coming weeks. I am not sure of the concrete date, but we are striving to get stage 3 into this new release. Currently the feature is baking in the main branch (which is already versioned 1.17). Hope this makes things clear. Feel free to let me know if there are more questions.

@JingyaHuang (Contributor)

Hi @pengwa, thanks for updating the branch and for the explanation. I did a quick test (training BERT on GLUE) with ZeRO-3 enabled. The training did not fail, but it popped up some error logs related to tracing. Have you seen that before?

(FYI, I tested with the ORT nightly build.)

    ======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
2023-10-26 09:46:58,557 orttraining [WARNING] - Fallback to PyTorch due to exception <class 'onnxruntime.training.ortmodule._fallback_exceptions.ORTModuleONNXModelException'> was triggered. Report this issue with a minimal repro at https://www.github.com/microsoft/onnxruntime. See details below:

RuntimeError: There was an error while exporting the PyTorch model to ONNX: 

Traceback (most recent call last):
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/ortmodule/_utils.py", line 325, in get_exception_as_string
    raise exception
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 397, in _get_exported_model
    torch.onnx.export(
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/onnx/utils.py", line 506, in export
    _export(
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/onnx/utils.py", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/onnx/utils.py", line 989, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/onnx/utils.py", line 893, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/jit/_trace.py", line 1268, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/ortmodule/_io.py", line 388, in forward
    return _transform_output_to_flat_tuple(self._original_module(*new_args, **new_kwargs))
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 1573, in forward
    outputs = self.bert(
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 1015, in forward
    embedding_output = self.embeddings(
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 232, in forward
    inputs_embeds = self.word_embeddings(input_ids)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1517, in _call_impl
    result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/utils/hooks/_subscriber_manager.py", line 245, in _pre_forward_module_with_kwargs_hook
    module_inputs, kwargs = sub.pre_forward_module_apply(self._run_ctx, module, module_inputs, kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/utils/hooks/_subscriber_base.py", line 85, in pre_forward_module_apply
    updated_args, updated_kwargs = self.pre_forward_module_apply_impl(run_rtx, module, args, kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/utils/hooks/_zero_offload_subscriber.py", line 528, in pre_forward_module_apply_impl
    rets = ORTZeROOffloadPreForwardFunction.apply(
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
RuntimeError: _Map_base::at


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/ortmodule/_utils.py", line 325, in get_exception_as_string
    raise exception
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 255, in forward
    build_gradient_graph = self._export_model(*inputs, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/ortmodule/_logger.py", line 158, in wrapper
    result = func(graph_execution_manager, *args, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/ortmodule/_logger.py", line 273, in wrapper
    result = func(graph_execution_manager, *args, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 294, in _export_model
    self._onnx_models.exported_model = self._get_exported_model(schema, *inputs, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 405, in _get_exported_model
    raise wrap_exception(  # noqa: B904
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/ortmodule/_fallback_exceptions.py", line 74, in wrap_exception
    raise new_exception(raised_exception) from raised_exception
onnxruntime.training.ortmodule._fallback_exceptions.ORTModuleONNXModelException: There was an error while exporting the PyTorch model to ONNX: 

Traceback (most recent call last):
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/ortmodule/_utils.py", line 325, in get_exception_as_string
    raise exception
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 397, in _get_exported_model
    torch.onnx.export(
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/onnx/utils.py", line 506, in export
    _export(
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/onnx/utils.py", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/onnx/utils.py", line 989, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/onnx/utils.py", line 893, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/jit/_trace.py", line 1268, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/ortmodule/_io.py", line 388, in forward
    return _transform_output_to_flat_tuple(self._original_module(*new_args, **new_kwargs))
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 1573, in forward
    outputs = self.bert(
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 1015, in forward
    embedding_output = self.embeddings(
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 232, in forward
    inputs_embeds = self.word_embeddings(input_ids)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1517, in _call_impl
    result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/utils/hooks/_subscriber_manager.py", line 245, in _pre_forward_module_with_kwargs_hook
    module_inputs, kwargs = sub.pre_forward_module_apply(self._run_ctx, module, module_inputs, kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/utils/hooks/_subscriber_base.py", line 85, in pre_forward_module_apply
    updated_args, updated_kwargs = self.pre_forward_module_apply_impl(run_rtx, module, args, kwargs)
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/onnxruntime/training/utils/hooks/_zero_offload_subscriber.py", line 528, in pre_forward_module_apply_impl
    rets = ORTZeROOffloadPreForwardFunction.apply(
  File "/home/onnxruntimedev/miniconda3/lib/python3.9/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
RuntimeError: _Map_base::at


  4%|████▊                                                                                                                  | 1/25 [00:01<00:28,  1.21s/it]Invalidate trace cache @ step 231: expected module 1, but got module 230
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:19<00:00,  1.41it/s]

Training completed. Do not forget to share your model on huggingface.co/models =)


{'train_runtime': 19.1384, 'train_samples_per_second': 10.45, 'train_steps_per_second': 1.306, 'train_loss': 0.7029191589355469, 'epoch': 1.0}             
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:19<00:00,  1.31it/s]
Training metrics(ORT):
 {'train_runtime': 19.1384, 'train_samples_per_second': 10.45, 'train_steps_per_second': 1.306, 'train_loss': 0.7029191589355469, 'epoch': 1.0}
Invalidate trace cache @ step 0 and module 0: cache has only 0 modules
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  6.81it/s]
Evaluation metrics:
 {'eval_loss': 0.6702812314033508, 'eval_accuracy': 0.62, 'eval_runtime': 1.203, 'eval_samples_per_second': 41.563, 'eval_steps_per_second': 5.819, 'epoch': 1.0}
.
----------------------------------------------------------------------

@pengwa (Author) commented Oct 26, 2023

> Hi @pengwa, thanks for updating the branch and for the explanation. I did a quick test (training BERT on GLUE) with ZeRO-3 enabled. The training did not fail, but it popped up some error logs related to tracing. Have you seen that before?
>
> (FYI, I tested with the ORT nightly build.)
>
> (full traceback quoted above)

Thanks @JingyaHuang for bringing this up! I believe this is related to an exporter issue in some PyTorch versions. Indeed, we need more version checks on the dependency libs, including PyTorch, DeepSpeed, and ORT. Let me collect the versions first and update the checks later.
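As a rough illustration of the kind of version gate being discussed (not code from this PR or from ONNX Runtime), here is a minimal sketch that refuses the stage-3 path unless the installed PyTorch, DeepSpeed, and ONNX Runtime meet some minimum versions. The minimum versions below are placeholders, not official requirements:

```python
# Minimal sketch of a dependency-version gate for the ZeRO stage-3 path.
# The minimum versions are placeholders, not the officially required ones.
from packaging import version

import torch
import deepspeed
import onnxruntime

_MIN_VERSIONS = {
    "torch": "2.1.0",         # placeholder
    "deepspeed": "0.10.0",    # placeholder
    "onnxruntime": "1.17.0",  # stage 3 is expected to ship with ORT 1.17
}


def ort_stage3_supported() -> bool:
    """Return True only if every dependency meets the (assumed) minimum version."""
    installed = {
        "torch": torch.__version__,
        "deepspeed": deepspeed.__version__,
        "onnxruntime": onnxruntime.__version__,
    }
    # Compare release tuples so local/nightly suffixes ("+cu118", ".dev2023...") don't matter.
    return all(
        version.parse(installed[name]).release >= version.parse(minimum).release
        for name, minimum in _MIN_VERSIONS.items()
    )


if not ort_stage3_supported():
    raise RuntimeError(
        "ORTModule ZeRO stage-3 support requires newer PyTorch/DeepSpeed/ORT versions."
    )
```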

@pengwa (Author) commented Nov 7, 2023

Hi @JingyaHuang, sorry for the delay. I have been focusing on optimizing performance and memory and did not spare more time for stricter version checks, so I will hand this work over to a teammate. :) Let me close this one; we will create a new PR later.

@pengwa closed this Nov 7, 2023
@JingyaHuang (Contributor)

@pengwa Sure, no probs!
