Commit

style
Cyrilvallez committed Dec 20, 2024
1 parent 67c3fcd commit c898d03
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions src/transformers/models/cohere/modular_cohere.py
@@ -28,17 +28,23 @@
 import torch.utils.checkpoint
 from torch import nn
 
+from ...cache_utils import Cache
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...modeling_outputs import CausalLMOutputWithPast
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
-from ...utils import logging, LossKwargs
+from ...utils import LossKwargs, logging
+from ..llama.modeling_llama import (
+    LlamaAttention,
+    LlamaForCausalLM,
+    LlamaMLP,
+    LlamaModel,
+    LlamaRotaryEmbedding,
+    eager_attention_forward,
+)
 from .configuration_cohere import CohereConfig
-
-from ..llama.modeling_llama import LlamaRotaryEmbedding, LlamaAttention, LlamaMLP, LlamaModel, LlamaForCausalLM, eager_attention_forward
 
 
 logger = logging.get_logger(__name__)
 
@@ -144,7 +150,9 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None):
         self.use_qk_norm = config.use_qk_norm
         if self.use_qk_norm:
             # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
-            self.q_norm = CohereLayerNorm(hidden_size=(config.num_attention_heads, self.head_dim), eps=config.layer_norm_eps)
+            self.q_norm = CohereLayerNorm(
+                hidden_size=(config.num_attention_heads, self.head_dim), eps=config.layer_norm_eps
+            )
             self.k_norm = CohereLayerNorm(
                 hidden_size=(config.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps
             )
