
Export in ONNX/FP16 of PY007/TinyLlama-1.1B-Chat-v0.2 fails #1399

Closed
claeyzre opened this issue Sep 20, 2023 · 9 comments · Fixed by #1432
Labels
bug Something isn't working

Comments

@claeyzre
Contributor

claeyzre commented Sep 20, 2023

System Info

Python: v3.9.13

Optimum: v1.13.1
Transformers: v4.33.2
Onnxruntime: v1.15.1
CUDA: v11.7

System: Windows 11

Who can help?

@fxmarty ?

Reproduction (minimal, reproducible, runnable)

While running this command:

optimum-cli export onnx --model PY007/TinyLlama-1.1B-Chat-v0.2  --task causal-lm-with-past --fp16 --for-ort --device cuda tiny-llamav0.2-onnx

This model is a "Tiny" Llama-2.

I got the following stack trace; I assume it occurs during the post-processing step:

(venv) PS D:\john.doe\optimum-test> optimum-cli export onnx --model PY007/TinyLlama-1.1B-Chat-v0.2  --task causal-lm-with-past --fp16 --for-ort --device cuda tiny-llamav0.2-onnx
The option --for-ort was passed, but its behavior is now the default in the ONNX exporter and passing it is not required anymore.
Framework not specified. Using pt to export to ONNX.
Downloading (…)lve/main/config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 652/652 [00:00<?, ?B/s]
D:\john.doe\optimum-test\venv\lib\site-packages\huggingface_hub\file_download.py:137: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in D:\cache\huggingface\hub. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
  warnings.warn(message)
Downloading model.safetensors: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.40G/4.40G [01:50<00:00, 39.8MB/s]
Downloading (…)neration_config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63.0/63.0 [00:00<?, ?B/s]
Downloading (…)okenizer_config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 762/762 [00:00<00:00, 48.8kB/s]
Downloading tokenizer.model: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500k/500k [00:00<00:00, 1.33MB/s]
Downloading (…)/main/tokenizer.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.84M/1.84M [00:00<00:00, 4.37MB/s]
Downloading (…)in/added_tokens.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 69.0/69.0 [00:00<?, ?B/s]
Downloading (…)cial_tokens_map.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 438/438 [00:00<00:00, 28.0kB/s]
Using the export variant default. Available variants are:
        - default: The default ONNX variant.
use_past = False is different than use_present_in_outputs = True, the value of use_present_in_outputs value will be used for the outputs.
Using framework PyTorch: 2.0.1+cu117
Overriding 1 configuration item(s)
        - use_cache -> True
D:\john.doe\optimum-test\venv\lib\site-packages\transformers\models\llama\modeling_llama.py:595: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if input_shape[-1] > 1:
D:\john.doe\optimum-test\venv\lib\site-packages\transformers\models\llama\modeling_llama.py:119: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if seq_len > self.max_seq_len_cached:
D:\john.doe\optimum-test\venv\lib\site-packages\transformers\models\llama\modeling_llama.py:348: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
D:\john.doe\optimum-test\venv\lib\site-packages\transformers\models\llama\modeling_llama.py:355: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
D:\john.doe\optimum-test\venv\lib\site-packages\transformers\models\llama\modeling_llama.py:365: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
============= Diagnostic Run torch.onnx.export version 2.0.1+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Saving external data to one file...
2023-09-20 16:16:08.5758546 [W:onnxruntime:, session_state.cc:1169 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2023-09-20 16:16:08.5949573 [W:onnxruntime:, session_state.cc:1171 onnxruntime::VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
Using framework PyTorch: 2.0.1+cu117
Overriding 1 configuration item(s)
        - use_cache -> True
Asked a sequence length of 16, but a sequence length of 1 will be used with use_past == True for `input_ids`.
============= Diagnostic Run torch.onnx.export version 2.0.1+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Traceback (most recent call last):
  File "C:\Program Files\Python39\lib\runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Program Files\Python39\lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "D:\john.doe\optimum-test\venv\Scripts\optimum-cli.exe\__main__.py", line 7, in <module>
  File "D:\john.doe\optimum-test\venv\lib\site-packages\optimum\commands\optimum_cli.py", line 163, in main
    service.run()
  File "D:\john.doe\optimum-test\venv\lib\site-packages\optimum\commands\export\onnx.py", line 232, in run
    main_export(
  File "D:\john.doe\optimum-test\venv\lib\site-packages\optimum\exporters\onnx\__main__.py", line 486, in main_export
    _, onnx_outputs = export_models(
  File "D:\john.doe\optimum-test\venv\lib\site-packages\optimum\exporters\onnx\convert.py", line 752, in export_models
    export(
  File "D:\john.doe\optimum-test\venv\lib\site-packages\optimum\exporters\onnx\convert.py", line 855, in export
    export_output = export_pytorch(
  File "D:\john.doe\optimum-test\venv\lib\site-packages\optimum\exporters\onnx\convert.py", line 572, in export_pytorch
    onnx_export(
  File "D:\john.doe\optimum-test\venv\lib\site-packages\torch\onnx\utils.py", line 506, in export
    _export(
  File "D:\john.doe\optimum-test\venv\lib\site-packages\torch\onnx\utils.py", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "D:\john.doe\optimum-test\venv\lib\site-packages\torch\onnx\utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "D:\john.doe\optimum-test\venv\lib\site-packages\torch\onnx\utils.py", line 989, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "D:\john.doe\optimum-test\venv\lib\site-packages\torch\onnx\utils.py", line 893, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "D:\john.doe\optimum-test\venv\lib\site-packages\torch\jit\_trace.py", line 1268, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "D:\john.doe\optimum-test\venv\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\john.doe\optimum-test\venv\lib\site-packages\torch\jit\_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "D:\john.doe\optimum-test\venv\lib\site-packages\torch\jit\_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "D:\john.doe\optimum-test\venv\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\john.doe\optimum-test\venv\lib\site-packages\torch\nn\modules\module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "D:\john.doe\optimum-test\venv\lib\site-packages\optimum\exporters\onnx\model_patcher.py", line 113, in patched_forward
    outputs = self.orig_forward(*args, **kwargs)
  File "D:\john.doe\optimum-test\venv\lib\site-packages\transformers\models\llama\modeling_llama.py", line 820, in forward
    outputs = self.model(
  File "D:\john.doe\optimum-test\venv\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\john.doe\optimum-test\venv\lib\site-packages\torch\nn\modules\module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "D:\john.doe\optimum-test\venv\lib\site-packages\transformers\models\llama\modeling_llama.py", line 708, in forward
    layer_outputs = decoder_layer(
  File "D:\john.doe\optimum-test\venv\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\john.doe\optimum-test\venv\lib\site-packages\torch\nn\modules\module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "D:\john.doe\optimum-test\venv\lib\site-packages\transformers\models\llama\modeling_llama.py", line 424, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "D:\john.doe\optimum-test\venv\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\john.doe\optimum-test\venv\lib\site-packages\torch\nn\modules\module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "D:\john.doe\optimum-test\venv\lib\site-packages\transformers\models\llama\modeling_llama.py", line 337, in forward
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 32 but got size 4 for tensor number 1 in the list.

Expected behavior

The export should occur without errors.
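
For context, the "Expected size 32 but got size 4" mismatch above lines up with TinyLlama using grouped-query attention (32 attention heads but only 4 key/value heads), so past key/value tensors built with the full head count cannot be concatenated with the traced key/value states. A quick way to confirm the head counts from the model config (a minimal sketch using the standard transformers config fields for Llama-style models):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("PY007/TinyLlama-1.1B-Chat-v0.2")
# TinyLlama uses grouped-query attention: fewer key/value heads than query heads.
print(config.num_attention_heads)   # 32
print(config.num_key_value_heads)   # 4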

@claeyzre added the bug label Sep 20, 2023
@natke

natke commented Oct 4, 2023

I am hitting this too

@fxmarty
Contributor

fxmarty commented Oct 5, 2023

Thank you for the report, looking at it shortly!

@fxmarty
Contributor

fxmarty commented Oct 5, 2023

Hi @claeyzre @natke, the issue is fixed in #1432

@natke

natke commented Oct 6, 2023

I can now do the export, thanks! But I get a similar error when I try to run the graph:

from transformers import AutoConfig, AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

print("Tokenizing ...")
tokenizer = AutoTokenizer.from_pretrained("PY007/TinyLlama-1.1B-step-50K-105b")

print("Generating model ...")
model = ORTModelForCausalLM.from_pretrained("PY007/TinyLlama-1.1B-step-50K-105b", export=True, cache_dir="__cache_dir")

print("Saving model ...")
model.save_pretrained("PY007/TinyLlama-1.1B-step-50K-105b".split("/")[-1] + "-onnx", cache_dir="__cache_dir")

print("Running generate ...")
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate
generate_ids = model.generate(inputs.input_ids, max_length=50, num_beams=5, repetition_penalty=2.0, num_return_sequences=1)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])

This is the error:

Running generate ...
2023-10-06 19:24:15.986215300 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running Add node. Name:'/model/layers.0/self_attn/Add_1' Status Message: /onnxruntime_src/onnxruntime/core/framework/execution_frame.cc:171 onnxruntime::common::Status onnxruntime::IExecutionFrame::GetOrCreateNodeOutputMLValue(int, int, const onnxruntime::TensorShape*, OrtValue*&, const onnxruntime::Node&) shape && tensor.Shape() == *shape was false. OrtValue shape verification failed. Current shape:{5,32,13,64} Requested shape:{5,4,13,64}

Traceback (most recent call last):
  File "/home/natke/Develop/samples/llama/run_llama_opt_ort.py", line 18, in <module>
    generate_ids = model.generate(inputs.input_ids, max_length=50, num_beams=5, repetition_penalty=2.0, num_return_sequences=1)
  File "/home/natke/miniconda3/envs/llama/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/natke/miniconda3/envs/llama/lib/python3.9/site-packages/transformers/generation/utils.py", line 1681, in generate
    return self.beam_search(
  File "/home/natke/miniconda3/envs/llama/lib/python3.9/site-packages/transformers/generation/utils.py", line 3020, in beam_search
    outputs = self(
  File "/home/natke/miniconda3/envs/llama/lib/python3.9/site-packages/optimum/modeling_base.py", line 90, in __call__
    return self.forward(*args, **kwargs)
  File "/home/natke/miniconda3/envs/llama/lib/python3.9/site-packages/optimum/onnxruntime/modeling_decoder.py", line 659, in forward
    outputs = self.decoder(
  File "/home/natke/miniconda3/envs/llama/lib/python3.9/site-packages/optimum/onnxruntime/base.py", line 68, in __call__
    return self.forward(*args, **kwargs)
  File "/home/natke/miniconda3/envs/llama/lib/python3.9/site-packages/optimum/onnxruntime/base.py", line 393, in forward
    self.session.run_with_iobinding(io_binding)
  File "/home/natke/miniconda3/envs/llama/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 331, in run_with_iobinding
    self._sess.run_with_iobinding(iobinding._iobinding, run_options)
RuntimeError: Error in execution: Non-zero status code returned while running Add node. Name:'/model/layers.0/self_attn/Add_1' Status Message: /onnxruntime_src/onnxruntime/core/framework/execution_frame.cc:171 onnxruntime::common::Status onnxruntime::IExecutionFrame::GetOrCreateNodeOutputMLValue(int, int, const onnxruntime::TensorShape*, OrtValue*&, const onnxruntime::Node&) shape && tensor.Shape() == *shape was false. OrtValue shape verification failed. Current shape:{5,32,13,64} Requested shape:{5,4,13,64}
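
For what it's worth, the failing shapes here ({5,32,13,64} produced vs. {5,4,13,64} requested) point at the same query/key-value head mismatch (batch of 5 beams, 13 tokens, head dim 64). One way to check what the exporter baked into the graph is to inspect the past_key_values inputs of the saved decoder ONNX file; a minimal sketch, assuming the export produced a decoder_model_merged.onnx (the exact filename varies with the Optimum version) inside the output folder from the script above:

import onnx

decoder = onnx.load("TinyLlama-1.1B-step-50K-105b-onnx/decoder_model_merged.onnx", load_external_data=False)
for inp in decoder.graph.input:
    if inp.name.startswith("past_key_values"):
        dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
        print(inp.name, dims)  # with the fix, the head dimension should be 4 (KV heads), not 32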

@natke

natke commented Oct 6, 2023

@fxmarty

@fxmarty
Contributor

fxmarty commented Oct 9, 2023

Thank you! Indeed, it should have been fixed by the latest commit (#1425). Could you give it another try?
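
For reference, to pick up a fix that is only on the main branch, installing Optimum from source should be enough, for example:

pip install -U git+https://github.com/huggingface/optimum.git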

@natke

natke commented Oct 9, 2023

Will do!

@natke

natke commented Oct 9, 2023

That worked - thank you!

@lin-lcx

lin-lcx commented Nov 16, 2023

I have the same problem:

# (Fragment of run_nougat_latex() from a larger script; imports such as time, torch, PIL.Image,
# ORTModelForVision2Seq, NougatTokenizerFast, NougatLaTexProcessor and the process_raw_latex_code
# helper, as well as args, device and t0, are defined elsewhere in that script.)
model = ORTModelForVision2Seq.from_pretrained('Norm/nougat-latex-base', export=True, provider="CUDAExecutionProvider").to(device)
t1 = time.time()
print("load model1 time = ", (t1 - t0))
# params = sum([v.numel() for k,v in model.state_dict().items()])
# init processor
t0 = time.time()
tokenizer = NougatTokenizerFast.from_pretrained(r'/home/kas/kas_workspace/linchengxuan/formula_recognition/nougat-latex-ocr-main/examples/data')
t1 = time.time()
print("load model2 time = ", (t1 - t0))
t0 = time.time()
latex_processor = NougatLaTexProcessor.from_pretrained(r'/home/kas/kas_workspace/linchengxuan/formula_recognition/nougat-latex-ocr-main/examples/data')
t1 = time.time()
print("load model3 time = ", (t1 - t0))

# run test
if args.img_path.endswith('png'):
    image = Image.open(args.img_path)
    if not image.mode == "RGB":
        image = image.convert('RGB')

    pixel_values = latex_processor(image, return_tensors="pt").pixel_values
    task_prompt = tokenizer.bos_token
    decoder_input_ids = tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
    with torch.no_grad():
        outputs = model.generate(
            pixel_values.to(device),
            decoder_input_ids=decoder_input_ids.to(device),
            max_length=model.decoder.config.max_length,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )
    sequence = tokenizer.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "").replace(tokenizer.bos_token, "")
    sequence = process_raw_latex_code(sequence)
    print(sequence)

load model1 time = 853.7959609031677
load model2 time = 0.06522369384765625
load model3 time = 0.002882719039916992
0211233.png
/home/kas/.conda/envs/torch/lib/python3.8/site-packages/transformers/generation/utils.py:1473: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use and modify the model generation configuration (see https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )
warnings.warn(
/home/kas/.conda/envs/torch/lib/python3.8/site-packages/transformers/generation/configuration_utils.py:418: UserWarning: num_beams is set to 1. However, early_stopping is set to True -- this flag is only used in beam-based generation modes. You should set num_beams>1 or unset early_stopping.
warnings.warn(
2023-11-16 16:38:08.089559621 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running Add node. Name:'/encoder/layers.3/blocks.1/Add_1' Status Message: /onnxruntime_src/onnxruntime/core/framework/execution_frame.cc:171 onnxruntime::common::Status onnxruntime::IExecutionFrame::GetOrCreateNodeOutputMLValue(int, int, const onnxruntime::TensorShape*, OrtValue*&, const onnxruntime::Node&) shape && tensor.Shape() == *shape was false. OrtValue shape verification failed. Current shape:{1,122,1024} Requested shape:{1,126,1024}
Traceback (most recent call last):
  File "run_latex_ocr-Copy1.py", line 130, in <module>
    run_nougat_latex()
  File "run_latex_ocr-Copy1.py", line 107, in run_nougat_latex
    outputs = model.generate(
  File "/home/kas/.conda/envs/torch/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/kas/.conda/envs/torch/lib/python3.8/site-packages/transformers/generation/utils.py", line 1548, in generate
    model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
  File "/home/kas/.conda/envs/torch/lib/python3.8/site-packages/transformers/generation/utils.py", line 661, in _prepare_encoder_decoder_kwargs_for_generation
    model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
  File "/home/kas/.conda/envs/torch/lib/python3.8/site-packages/optimum/onnxruntime/base.py", line 68, in __call__
    return self.forward(*args, **kwargs)
  File "/home/kas/.conda/envs/torch/lib/python3.8/site-packages/optimum/onnxruntime/modeling_seq2seq.py", line 430, in forward
    self.session.run_with_iobinding(io_binding)
  File "/home/kas/.conda/envs/torch/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 331, in run_with_iobinding
    self._sess.run_with_iobinding(iobinding._iobinding, run_options)
RuntimeError: Error in execution: Non-zero status code returned while running Add node. Name:'/encoder/layers.3/blocks.1/Add_1' Status Message: /onnxruntime_src/onnxruntime/core/framework/execution_frame.cc:171 onnxruntime::common::Status onnxruntime::IExecutionFrame::GetOrCreateNodeOutputMLValue(int, int, const onnxruntime::TensorShape*, OrtValue*&, const onnxruntime::Node&) shape && tensor.Shape() == *shape was false. OrtValue shape verification failed. Current shape:{1,122,1024} Requested shape:{1,126,1024}
