fix jit model #566

Merged
merged 21 commits into huggingface:main on Mar 19, 2024

Conversation

jiqing-feng
Collaborator

@jiqing-feng jiqing-feng commented Feb 18, 2024

Hi @echarlaix. This PR fixes the JIT model:

  1. Enable model_dtype so that the pkv (past key values) are generated with the correct data type in forward.
  2. Enable _dtype so that the pkv are generated with the correct data type when preparing the JIT inputs.
  3. Use our own prepare_inputs_for_generation to make sure the JIT trace works (a rough sketch of the idea is included at the end of this comment).
  4. Remove autocast from the class: users can apply it outside, and we shouldn't use it when the model dtype is bf16 or fp16.

Would you please help me review it? Thanks!
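
For illustration, a rough sketch of what such a prepare_inputs_for_generation can look like (the exact slicing and names here are illustrative, not necessarily the code merged in this PR). It only works on tensors, so it does not rely on attributes a RecursiveScriptModule lacks:

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
    attention_mask = kwargs.get("attention_mask")
    position_ids = kwargs.get("position_ids")

    if past_key_values is not None:
        # Assuming the standard (batch, num_heads, past_seq_len, head_dim) cache layout,
        # keep only the tokens that were not consumed in previous steps.
        past_length = past_key_values[0][0].shape[2]
        input_ids = input_ids[:, past_length:]

    if position_ids is None and attention_mask is not None:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    if position_ids is not None:
        # position_ids must stay the same length as input_ids (can be > 1 with assisted decoding).
        position_ids = position_ids[:, -input_ids.shape[-1] :]

    return {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
        "use_cache": kwargs.get("use_cache"),
    }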

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@jiqing-feng jiqing-feng mentioned this pull request Feb 19, 2024
Collaborator

@echarlaix echarlaix left a comment

Thanks for the fix!

Comment on lines 324 to 325
self.model_dtype = kwargs.get("model_dtype", self.dtype)
self._dtype = self.model_dtype
Collaborator

I think we should deprecate the model_dtype attribute (redundant with _dtype), adding a warning that it will be removed after v1.18.0

    @property
    def model_dtype(self):
        # add a warning 
        return self._dtype 
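
A possible shape for that deprecation (just a sketch, assuming a module-level logger; the exact message and the merged code may differ):

    @property
    def model_dtype(self):
        # warn that the attribute is deprecated and forward to _dtype
        logger.warning(
            "`model_dtype` is deprecated and will be removed after v1.18.0, please use `_dtype` instead."
        )
        return self._dtype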

Comment on lines 349 to 350
self._reorder_cache = self.model_cls._reorder_cache.__get__(self)
self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)
Collaborator

I think it makes sense to keep it, as _reorder_cache and prepare_inputs_for_generation can be different depending on the modeling.

if hasattr(self.model_cls, "_convert_to_standard_cache"):
    self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
if hasattr(self.model_cls, "_convert_to_bloom_cache"):
    self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache
if warmup:
    self._init_warmup()
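
As a side note on the __get__ calls quoted above, a tiny standalone illustration (hypothetical classes, not code from this PR) of what the binding does:

class Original:
    def helper(self):
        return type(self).__name__

class Wrapper:
    pass

w = Wrapper()
# Function.__get__(instance) binds a function taken from another class to this
# instance, so later calls receive the wrapper object as `self`.
w.helper = Original.helper.__get__(w)
print(w.helper())  # prints "Wrapper"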

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
Collaborator

I think we can remove this.

@@ -215,14 +216,6 @@ def to(self, device: Union[torch.device, str]):
    def can_generate(self):
        return isinstance(self, GenerationMixin)

    def _call_model(self, *args, **kwargs):
Collaborator

It also might make sense to add a test to verify this doesn't break support for models traced with autocasting enabled. cc @ofirzaf, do you know if there is any tiny model on the hub we can use for this (https://huggingface.co/Intel/q8_tiny_starcoder_py/ ?)

Contributor

You can use this one, yes

Collaborator

We actually should have a test for this to verify that no modifications will break support. Would you mind adding it @jiqing-feng in https://github.com/huggingface/optimum-intel/blob/v1.15.2/tests/ipex/test_modeling.py#L201? Also, given @ofirzaf's explanations above, removing _call_model will likely break support for q8_tiny_starcoder_py.

@ofirzaf
Contributor

ofirzaf commented Feb 19, 2024

What is this PR trying to fix?

The autocasting was added here because autocasting with a traced model is not optional: either the model runs with autocasting or it runs without. It is an attribute of the model, not something the user can decide upon.

Regarding prepare_inputs_for_generation, the current implementation derives the correct implementation from the original model's architecture. This enables wider support for more model architectures and scales to new architectures out of the box. The proposed implementation might not support a wide variety of models and disables support for assisted generation methods.
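
For context, a minimal sketch of why autocasting becomes a property of the traced model itself (a toy module, not the IPEX export path):

import torch

model = torch.nn.Linear(8, 8).eval()
example = torch.randn(1, 8)

# Tracing inside an autocast context records the dtype casts in the graph itself.
with torch.no_grad(), torch.autocast("cpu", dtype=torch.bfloat16):
    traced = torch.jit.trace(model, example)
    traced = torch.jit.freeze(traced)

# The recorded casts are replayed at inference time even without re-entering
# autocast here, so the traced model effectively always runs "with autocast".
out = traced(example)
print(out.dtype)  # typically torch.bfloat16, since the casts were baked in at trace time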

@jiqing-feng
Collaborator Author

jiqing-feng commented Feb 20, 2024

Hi @echarlaix @ofirzaf. What I did is based on the latest version of transformers and tries to enable the llama JIT model.

  1. I rewrote prepare_inputs_for_generation because the llama implementation of this function fails under JIT: the traced model does not have the layers attribute, see the traceback below.
  2. Autocast is removed because PyTorch does not recommend using autocast on a bf16 model, and it also avoids the CI failure here: transformers_outputs.logits is fp32 while outputs.logits is bf16 because of the autocast.

Here is my script and traceback:

import torch
from transformers import AutoTokenizer, pipeline
from optimum.intel import IPEXModelForCausalLM

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, torch_dtype=torch.bfloat16)
generation_kwargs = {"do_sample": False, "max_new_tokens": 32, "num_beams": 4}
tokenizer = AutoTokenizer.from_pretrained(model_id)
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

print(text_generator("I am happy today because ", **generation_kwargs))

Traceback:

Traceback (most recent call last):
  File "/home/jiqingfe/test.py", line 13, in <module>
    print(text_generator("I am happy today because ", **generation_kwargs))
  File "/home/jiqingfe/transformers/src/transformers/pipelines/text_generation.py", line 241, in __call__
    return super().__call__(text_inputs, **kwargs)
  File "/home/jiqingfe/transformers/src/transformers/pipelines/base.py", line 1196, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/home/jiqingfe/transformers/src/transformers/pipelines/base.py", line 1203, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/home/jiqingfe/transformers/src/transformers/pipelines/base.py", line 1102, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/home/jiqingfe/transformers/src/transformers/pipelines/text_generation.py", line 328, in _forward
    generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
  File "/home/jiqingfe/miniconda3/envs/ipex/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/jiqingfe/transformers/src/transformers/generation/utils.py", line 1603, in generate
    return self.beam_search(
  File "/home/jiqingfe/transformers/src/transformers/generation/utils.py", line 2967, in beam_search
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
  File "/home/jiqingfe/transformers/src/transformers/models/llama/modeling_llama.py", line 1237, in prepare_inputs_for_generation
    if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
  File "/home/jiqingfe/miniconda3/envs/ipex/lib/python3.10/site-packages/torch/jit/_script.py", line 823, in __getattr__
    return super().__getattr__(attr)
  File "/home/jiqingfe/miniconda3/envs/ipex/lib/python3.10/site-packages/torch/jit/_script.py", line 530, in __getattr__
    return super().__getattr__(attr)
  File "/home/jiqingfe/miniconda3/envs/ipex/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1688, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'RecursiveScriptModule' object has no attribute 'layers'

@ofirzaf
Contributor

ofirzaf commented Feb 20, 2024

@jiqing-feng I think the problem might be in the tracing process. The layers attribute is part of the new KVCache abstraction in HF/transformers. I think you need to make sure that you don't hit that if branch during tracing so that it gets left out of the graph, and then you shouldn't even execute it when you infer with the traced model (need to check that). Anyway, if you use the proposed implementation of prepare_inputs_for_generation, you limit the usability of IPEXModels to a narrow family of models.

@jiqing-feng
Collaborator Author

Hi @echarlaix @ofirzaf. Thanks for your review.

  1. For assisted decoding, I have fixed it with minor changes; it now supports assisted decoding.
  2. For prepare_inputs_for_generation, _reorder_cache and so on: we don't expect IPEXModel to support all models, covering the model types in our support list should be enough. There is no way to avoid using prepare_inputs_for_generation and _reorder_cache when the model is traced to graph mode, so I can only copy these functions from BaseModelForCausalLM to make sure generation works.

Hope you understand, thanks!

@jiqing-feng
Collaborator Author

Hi @echarlaix @ofirzaf. Thanks for your fix adding _prepare_inputs_for_generation_for_llama, it is great!

I have some minor changes to support the low-precision (bf16) model, would you please help me review them? Thanks!

Collaborator

@echarlaix echarlaix left a comment

Thanks for iterating on it @jiqing-feng!

Comment on lines 417 to 416
position_ids = position_ids[:, -1].unsqueeze(-1)
position_ids = position_ids[:, -input_ids.shape[-1] :]
Collaborator

Why is this modification needed?

Collaborator Author

position_ids should always have the same size as input_ids; we cannot assume the length is 1 while pkv exists (for example, with assisted decoding).
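
A toy illustration of the difference (hypothetical shapes, just to show the point):

import torch

# With assisted decoding, several draft tokens are verified in one forward pass,
# so input_ids can be longer than 1 even though past_key_values already exists.
input_ids = torch.tensor([[11, 12, 13]])      # 3 candidate tokens in this step
position_ids = torch.arange(8).unsqueeze(0)   # positions for the whole sequence so far

old = position_ids[:, -1].unsqueeze(-1)        # shape (1, 1): too short
new = position_ids[:, -input_ids.shape[-1] :]  # shape (1, 3): matches input_ids
print(old.shape, new.shape)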

Collaborator

Could you then add a test that would fail without this fix and pass with it?

Contributor

Shouldn't this be taken care of by prepare_inputs_for_generation?

Collaborator Author

It's also an option, WDYT @echarlaix ?

Collaborator

It's already integrated in prepare_inputs_for_generation for each modeling, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1237, so it could actually make sense to remove it here altogether.

@jiqing-feng
Collaborator Author

jiqing-feng commented Feb 28, 2024

Hi @echarlaix, referring to this comment, do you mind clarifying what kind of test I should add? Thanks!
Besides, I am not saying that we cannot use autocast; I propose that the autocast should be outside the model, like:

with torch.autocast(device.type, data_type):
    model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, torch_dtype=torch.bfloat16)
    model(**inputs)

Adding autocast inside will cause unexpected issues, see here.

@jiqing-feng
Collaborator Author

Hi @echarlaix, I have reverted "_call_model", would you please help review and merge it? Thanks!

@jiqing-feng
Collaborator Author

Hi @echarlaix. Do you mind having a look at this PR? I think it is ready to merge.

@jiqing-feng
Collaborator Author

Hi @ofirzaf, do you mind confirming whether it can be merged? Thanks!

@ofirzaf
Contributor

ofirzaf commented Mar 7, 2024

LGTM

@jiqing-feng
Collaborator Author

Hi @echarlaix. I have added a test for assisted decoding, which feeds the model both past_key_values and input_ids with length > 1.
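
Roughly, such a test can look like the following (a sketch only; the tiny model id and assertions are illustrative, not the exact test added in this PR):

from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # illustrative tiny model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForCausalLM.from_pretrained(model_id, export=True)
# A plain transformers model acts as the draft/assistant model.
assistant = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("This is a sample input", return_tensors="pt")
# Assisted generation verifies several draft tokens per step, so the traced model
# receives past_key_values together with input_ids of length > 1.
outputs = model.generate(**inputs, assistant_model=assistant, do_sample=False, max_new_tokens=8)
assert outputs.shape[-1] > inputs["input_ids"].shape[-1]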

@jiqing-feng
Collaborator Author

Hi @echarlaix @ofirzaf. You were right, we don't need to prepare position_ids in the forward because prepare_inputs_for_generation already does it. Would you please review it again? It should be ready to merge :) Thanks!

@echarlaix echarlaix merged commit 9813f90 into huggingface:main Mar 19, 2024
9 of 10 checks passed
@jiqing-feng jiqing-feng deleted the jit branch October 9, 2024 03:10