diff --git a/src/transformers/models/aria/configuration_aria.py b/src/transformers/models/aria/configuration_aria.py
index b9e1705e6f672d..ad0df22c96e276 100644
--- a/src/transformers/models/aria/configuration_aria.py
+++ b/src/transformers/models/aria/configuration_aria.py
@@ -91,8 +91,8 @@ def __init__(
         self.image_size = image_size
         self.attention_dropout = attention_dropout
         self.layer_norm_eps = layer_norm_eps
-        self.hidden_act = hidden_act
         self._supports_sdpa = False
+        self.hidden_act = hidden_act
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py
index 459ff57be17bb8..019ae5440a11f4 100644
--- a/src/transformers/models/aria/modeling_aria.py
+++ b/src/transformers/models/aria/modeling_aria.py
@@ -1374,11 +1374,9 @@ def forward(self, input, tokens_per_expert):
         if torch.cuda.is_available():
             torch.cuda.set_device(input.device)
         original_dtype = input.dtype
-        return experts_gemm(
-            input.to(torch.bfloat16),
-            self.weight.to(torch.bfloat16),
-            tokens_per_expert
-        ).to(original_dtype)
+        return experts_gemm(input.to(torch.bfloat16), self.weight.to(torch.bfloat16), tokens_per_expert).to(
+            original_dtype
+        )
 
 
 class AriaGroupedMLP(nn.Module):
@@ -2811,7 +2809,6 @@ def forward(
         )
 
 
-
 @dataclass
 class AriaCausalLMOutputWithPast(ModelOutput):
     """
@@ -3076,4 +3073,3 @@ def forward(
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
-
diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py
index 158da72778c0ad..6aa1ba6afb9efe 100644
--- a/src/transformers/models/aria/modular_aria.py
+++ b/src/transformers/models/aria/modular_aria.py
@@ -36,8 +36,8 @@
     LlamaForCausalLM,
     LlamaMLP,
     LlamaModel,
+    LlamaPreTrainedModel,
     LlamaRMSNorm,
-    LlamaPreTrainedModel
 )
 from ..llava.modeling_llava import LlavaCausalLMOutputWithPast
 from ..siglip.configuration_siglip import SiglipVisionConfig
@@ -765,9 +765,10 @@ def __init__(
         moe_z_loss_coeff: float = 1e-5,
         moe_aux_loss_coeff: float = 1e-3,
         moe_num_shared_experts: int = 2,
+        pad_token_id=2,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(pad_token_id=pad_token_id, **kwargs)
         self.moe_intermediate_size = moe_intermediate_size
         self.moe_num_experts = moe_num_experts
         self.moe_topk = moe_topk
@@ -863,7 +864,6 @@ def _supports_sdpa(self):
         return self.language_model._supports_sdpa
 
     def _init_weights(self, module):
-        std = self.config.text_config.initializer_range
         if hasattr(self.config, 'initializer_range'):
             std = self.config.initializer_range
         elif hasattr(self.config, 'text_config'):
@@ -880,6 +880,10 @@
             module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, AriaGroupedGEMM):
            module.weight.data.normal_(mean=0.0, std=std)
+        elif isinstance(module, nn.Conv2d):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if hasattr(module, "bias") and module.bias is not None:
+                module.bias.data.zero_()
 
 
 # adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/router.py#L96-L304
@@ -1124,10 +1128,6 @@ def _init_weights(self, module):
             module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, AriaGroupedGEMM):
             module.weight.data.normal_(mean=0.0, std=std)
-        elif isinstance(module, nn.Conv2d):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if hasattr(module, "bias") and module.bias is not None:
-                module.bias.data.zero_()
 
 
 class AriaTextModel(LlamaModel, AriaTextPreTrainedModel):
diff --git a/tests/models/aria/test_modeling_aria.py b/tests/models/aria/test_modeling_aria.py
index 5bc8cb3541fbc6..b449dfc6dd98d0 100644
--- a/tests/models/aria/test_modeling_aria.py
+++ b/tests/models/aria/test_modeling_aria.py
@@ -287,7 +287,11 @@ def test_sdpa_can_dispatch_on_flash(self):
         pass
 
     @unittest.skip(reason="")
-    def test_new_cache_format(self):
+    def test_new_cache_format_1(self):
+        pass
+
+    @unittest.skip(reason="")
+    def test_new_cache_format_0(self):
         pass
 
     @unittest.skip(reason="Feedforward chunking is not yet supported")