From b639d1761679d8b3ba05fbfaa2b024388be888f5 Mon Sep 17 00:00:00 2001 From: whywhy-rtx3090 <43395692+why-in-Shanghaitech@users.noreply.github.com> Date: Tue, 19 Nov 2024 10:45:04 +0800 Subject: [PATCH] fix: attention implementation initialization (#9) when initializing the model with no explicit declare which attention implementation to use, the original implementation will throw an error. This is because the llama init function will change the attn implementation to sdpa, which is not implemented in lckv yet. We fix it by passing a copy of the config to the llama init function. --- models/modeling_lckv.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/models/modeling_lckv.py b/models/modeling_lckv.py index bde4c04..543f8b4 100644 --- a/models/modeling_lckv.py +++ b/models/modeling_lckv.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch LLaMA model.""" +import copy import math from typing import List, Optional, Tuple, Union @@ -301,7 +302,8 @@ class LCKVLlamaPreTrainedModel(LlamaPreTrainedModel): class LCKVLlamaModel(LCKVLlamaPreTrainedModel, LlamaModel): def __init__(self, config: LCKVLlamaConfig): - LlamaModel.__init__(self, config) + LCKVLlamaPreTrainedModel.__init__(self, config) + LlamaModel.__init__(self, copy.deepcopy(config)) # copy config to avoid modifying the original self.layers = nn.ModuleList([LCKVLlamaDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.parser = LayerTypeParser(config.layer_types) @@ -717,7 +719,8 @@ def _update_causal_mask( class LCKVLlamaForCausalLM(LCKVLlamaPreTrainedModel, LlamaForCausalLM): def __init__(self, config): - LlamaForCausalLM.__init__(self, config) + LCKVLlamaPreTrainedModel.__init__(self, config) + LlamaForCausalLM.__init__(self, copy.deepcopy(config)) # copy config to avoid modifying the original self.model = LCKVLlamaModel(config) # Initialize weights and apply final processing