diff --git a/src/transformers/models/cohere/modular_cohere.py b/src/transformers/models/cohere/modular_cohere.py
index 5538ed415c4935..b3483103461e9a 100644
--- a/src/transformers/models/cohere/modular_cohere.py
+++ b/src/transformers/models/cohere/modular_cohere.py
@@ -28,17 +28,23 @@
 import torch.utils.checkpoint
 from torch import nn
 
+from ...cache_utils import Cache
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...modeling_outputs import CausalLMOutputWithPast
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
-from ...utils import logging, LossKwargs
+from ...utils import LossKwargs, logging
+from ..llama.modeling_llama import (
+    LlamaAttention,
+    LlamaForCausalLM,
+    LlamaMLP,
+    LlamaModel,
+    LlamaRotaryEmbedding,
+    eager_attention_forward,
+)
 from .configuration_cohere import CohereConfig
-from ..llama.modeling_llama import LlamaRotaryEmbedding, LlamaAttention, LlamaMLP, LlamaModel, LlamaForCausalLM, eager_attention_forward
-
 
 
 logger = logging.get_logger(__name__)
 
@@ -144,7 +150,9 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None):
         self.use_qk_norm = config.use_qk_norm
         if self.use_qk_norm:
             # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
-            self.q_norm = CohereLayerNorm(hidden_size=(config.num_attention_heads, self.head_dim), eps=config.layer_norm_eps)
+            self.q_norm = CohereLayerNorm(
+                hidden_size=(config.num_attention_heads, self.head_dim), eps=config.layer_norm_eps
+            )
             self.k_norm = CohereLayerNorm(
                 hidden_size=(config.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps
             )