SDPA not implemented error #9
Hi! Thank you for your question. I am a bit confused, since this problem does not exist in Llama, which also supports sdpa... The model loading seems correct, and the first call to

All I could come up with are two ad-hoc solutions.

Option 1: force the eager implementation inside the model `__init__`:

```diff
  def __init__(self, config: LCKVQwen2Config):
+     config._attn_implementation = "eager"
      Qwen2Model.__init__(self, config)
      self.layers = nn.ModuleList([LCKVQwen2DecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
      self.parser = LayerTypeParser(config.layer_types)
```

Option 2: cheat with an eager implementation but pretend it is sdpa. Modify:

```diff
  LCKV_LLAMA_ATTENTION_CLASSES = {
      "eager": LCKVLlamaAttention,
      "flash_attention_2": LCKVLlamaFlashAttention2,
+     "sdpa": LCKVLlamaAttention,
  }
```

I hope this works. To find the exact cause I may need to see more of the code in your notebook. Anyway, I will put an lckv sdpa implementation on the agenda.
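If it helps with debugging, a quick way to confirm which implementation was actually selected after loading (a small sketch; `model_class` and `model_name_or_path` are the same placeholders used in the snippets elsewhere in this thread):

```python
# Illustrative check: inspect which attention implementation ended up on the config.
model = model_class.from_pretrained(model_name_or_path, attn_implementation="eager")
print(model.config._attn_implementation)  # expect "eager" once one of the workarounds is applied
```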
I found the exact cause of the problem: I forgot to delete the model already loaded in the kernel, so I just had to reset the kernel and run again. But now I face another problem: the model reads the custom config but doesn't apply it, and uses the parent class of the custom config instead. Can you tell me what to do in this situation? P.S.: I tried making the custom config inherit from PreTrainedConfig, but I still get the same error. The error looks like this:
Thanks for the reply! LCKV did this by adding a (lines 289 to 290 in cf50b6c).
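(For context, a common transformers pattern for this, which may or may not be exactly what those two lines do, is to set `config_class` on the custom model class so that `from_pretrained` instantiates the custom config instead of the parent one. An illustrative sketch using the class names from this thread:)

```python
from transformers import Qwen2ForCausalLM

class LCKVQwen2ForCausalLM(Qwen2ForCausalLM):
    # Illustrative: from_pretrained will now build an LCKVQwen2Config
    # (the custom config defined in the notebook) rather than the stock Qwen2Config.
    config_class = LCKVQwen2Config
```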
You may also do this:

```diff
  model_class, tokenizer_class = MODEL_CLASSES[model_type]
  tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
+ config = LCKVQwen2Config.from_pretrained(model_name_or_path)
+ # do some configurations...
- model = model_class.from_pretrained(model_name_or_path, attn_implementation="eager")
+ model = model_class.from_pretrained(model_name_or_path, config=config, attn_implementation="eager")
```
Thanks for your help, I'm grateful. It helped with the problems accessing the config args. I was also asking about prepare_inputs_for_generation in the CausalLM class: it returns a NoneType object, so I added a return statement for the model_inputs, but then the model responded with random gibberish. I hope you can help me with this problem. I will give you access to the notebook as a reply in the email if you want to see the whole code and check for yourself.
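Concretely, the change looks roughly like this (an illustrative sketch, not the repo's exact code; how model_inputs is assembled is simplified here):

```python
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
    # Assemble the inputs for the next generation step (simplified for illustration).
    model_inputs = {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "attention_mask": attention_mask,
    }
    model_inputs.update(kwargs)
    # The override previously fell off the end and implicitly returned None,
    # which is what generate() choked on; the fix is simply to return the dict.
    return model_inputs
```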
Hi! I solved the CausalLM error by returning the model_inputs and making the layer types config default to None, and it works completely fine. However, I now get a RuntimeError caused by a tensor size mismatch when the causal mask is added to the attention weights, so I was wondering if you could help me solve this problem. This is the layer types config I tried:

This is the RuntimeError I receive after I change the layer types config:
Yes... I can reproduce this bug. I haven't tested generation with the
I have just pushed a bugfix. Hopefully it will work.
Thank you, the bugfix worked and I could change the middle layers to the layers ahead.
Thanks, I just tried returning the model_inputs, and the bugfix helped with using different layer configs. It works just fine.
When initializing the model without explicitly declaring which attention implementation to use, the original implementation would throw an error. This is because the llama init function changes the attn implementation to sdpa, which is not implemented in lckv yet. We fix it by passing a copy of the config to the llama init function.
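A minimal sketch of that idea, written here with the Qwen2 class names from this thread rather than the repo's Llama code, so the details are illustrative rather than the actual fix:

```python
import copy

from torch import nn
from transformers import Qwen2Model

# LCKVQwen2Config, LCKVQwen2DecoderLayer and LayerTypeParser are the custom
# classes defined elsewhere in the notebook discussed in this thread.

class LCKVQwen2Model(Qwen2Model):
    def __init__(self, config: LCKVQwen2Config):
        # Hand the parent __init__ its own copy of the config, so the attention
        # implementation auto-detection it triggers (which may select "sdpa")
        # does not overwrite the implementation on the config used to build the
        # LCKV decoder layers below.
        Qwen2Model.__init__(self, copy.deepcopy(config))
        self.layers = nn.ModuleList(
            [LCKVQwen2DecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )
        self.parser = LayerTypeParser(config.layer_types)
```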
I'm trying to use the LCKV caching technique on a Qwen2 LLM and I noticed that you didn't implement sdpa through torch. So when loading the model I tried `attn_implementation="eager"`, but it still didn't work, because when I inherit the model it uses its own config and sees that the support for sdpa is equal to True. Do you have any idea how to solve this problem?

This is the error log from the console below:
```
ValueError                                Traceback (most recent call last)
Cell In[83], line 34
     32 model_class, tokenizer_class = MODEL_CLASSES[model_type]
     33 tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
---> 34 model = model_class.from_pretrained(model_name_or_path, attn_implementation="eager")
     35 model.eval()  # Set model to evaluation mode
     37 # Prepare Input

File /opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py:3886, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   3880 config = cls._autoset_attn_implementation(
   3881     config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
   3882 )
   3884 with ContextManagers(init_contexts):
   3885     # Let's make sure we don't run the init function of buffer modules
-> 3886     model = cls(config, *model_args, **model_kwargs)
   3888 # make sure we use the model's config since the init call might have copied it
   3889 config = model.config

Cell In[75], line 4, in LCKVQwen2ForCausalLM.__init__(self, config)
      2 def __init__(self, config):
      3     Qwen2ForCausalLM.__init__(self, config)
----> 4     self.model = LCKVQwen2Model(config)
      6     # Initialize weights and apply final processing
      7     self.post_init()

Cell In[74], line 3, in LCKVQwen2Model.__init__(self, config)
      2 def __init__(self, config: LCKVQwen2Config):
----> 3     Qwen2Model.__init__(self, config)
      4     self.layers = nn.ModuleList([LCKVQwen2DecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
      5     self.parser = LayerTypeParser(config.layer_types)

File /opt/conda/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:864, in Qwen2Model.__init__(self, config)
    863 def __init__(self, config: Qwen2Config):
--> 864     super().__init__(config)
    865     self.padding_idx = config.pad_token_id
    866     self.vocab_size = config.vocab_size

File /opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py:1404, in PreTrainedModel.__init__(self, config, *inputs, **kwargs)
   1398     raise ValueError(
   1399         f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
   1400         "`PretrainedConfig`. To create a model from a pretrained model use "
   1401         f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
   1402     )
   1403 # Save config and origin of the pretrained weights if given in model
-> 1404 config = self._autoset_attn_implementation(
   1405     config, torch_dtype=torch.get_default_dtype(), check_device_map=False
   1406 )
   1407 self.config = config
   1409 self.name_or_path = config.name_or_path

File /opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py:1581, in PreTrainedModel._autoset_attn_implementation(cls, config, use_flash_attention_2, torch_dtype, device_map, check_device_map)
   1572     cls._check_and_enable_flash_attn_2(
   1573         config,
   1574         torch_dtype=torch_dtype,
   (...)
   1577         check_device_map=check_device_map,
   1578     )
   1579 elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
   1580     # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
-> 1581     config = cls._check_and_enable_sdpa(
   1582         config,
   1583         hard_check_only=False if requested_attn_implementation is None else True,
   1584     )
   1586     if (
   1587         torch.version.hip is not None
   1588         and config._attn_implementation == "sdpa"
   1589         and torch.cuda.device_count() > 1
   1590     ):
   1591         logger.warning_once(
   1592             "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
   1593         )

File /opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py:1776, in PreTrainedModel._check_and_enable_sdpa(cls, config, hard_check_only)
   1774 if hard_check_only:
   1775     if not cls._supports_sdpa:
-> 1776         raise ValueError(
   1777             f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
   1778             " Please request the support for this architecture: huggingface/transformers#28005. If you believe"
   1779             ' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
   1780         )
   1781 if not is_torch_sdpa_available():
   1782     raise ImportError(
   1783         "PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1."
   1784     )

ValueError: LCKVQwen2Model does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet. Please request the support for this architecture: huggingface/transformers#28005. If you believe this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`
```