fix: attention implementation initialization (#9)
When the model is initialized without explicitly declaring which attention implementation to use, the original implementation throws an error. This happens because the Llama init function changes the attention implementation to sdpa, which is not yet implemented in LCKV. We fix it by passing a copy of the config to the Llama init function, so the original config is left unmodified.
why-in-Shanghaitech committed Nov 19, 2024
1 parent 0dc9a12 commit b639d17
Showing 1 changed file with 5 additions and 2 deletions.
models/modeling_lckv.py: 5 additions & 2 deletions

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch LLaMA model."""
+import copy
 import math
 from typing import List, Optional, Tuple, Union

@@ -301,7 +302,8 @@ class LCKVLlamaPreTrainedModel(LlamaPreTrainedModel):

 class LCKVLlamaModel(LCKVLlamaPreTrainedModel, LlamaModel):
     def __init__(self, config: LCKVLlamaConfig):
-        LlamaModel.__init__(self, config)
+        LCKVLlamaPreTrainedModel.__init__(self, config)
+        LlamaModel.__init__(self, copy.deepcopy(config))  # copy config to avoid modifying the original
         self.layers = nn.ModuleList([LCKVLlamaDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
         self.parser = LayerTypeParser(config.layer_types)

@@ -717,7 +719,8 @@ def _update_causal_mask(

 class LCKVLlamaForCausalLM(LCKVLlamaPreTrainedModel, LlamaForCausalLM):
     def __init__(self, config):
-        LlamaForCausalLM.__init__(self, config)
+        LCKVLlamaPreTrainedModel.__init__(self, config)
+        LlamaForCausalLM.__init__(self, copy.deepcopy(config))  # copy config to avoid modifying the original
         self.model = LCKVLlamaModel(config)

         # Initialize weights and apply final processing
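
For context on what the fix changes in practice, below is a minimal usage sketch. It is not part of the commit; it assumes LCKVLlamaConfig and LCKVLlamaForCausalLM are importable from models/modeling_lckv.py and that the default config values (for example layer_types) are usable placeholders, which may not hold verbatim in this repo.

# Minimal usage sketch (not from the commit). Assumes LCKVLlamaConfig and
# LCKVLlamaForCausalLM are importable from models/modeling_lckv.py and that
# the default config values (e.g. layer_types) are acceptable placeholders.
from models.modeling_lckv import LCKVLlamaConfig, LCKVLlamaForCausalLM

# No attn_implementation is declared explicitly here. Before this fix, the Llama
# init function could rewrite the config's attention implementation to sdpa,
# which LCKV does not implement, and initialization would raise. With the fix,
# the Llama init functions receive copy.deepcopy(config), so only a throwaway
# copy is mutated and the original config keeps its LCKV-compatible setting.
config = LCKVLlamaConfig()
model = LCKVLlamaForCausalLM(config)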
