Commit

Streamline modular_aria format
aymeric-roucher committed Oct 24, 2024
1 parent 6a8c805 commit ae19ca6
Showing 4 changed files with 16 additions and 16 deletions.
src/transformers/models/aria/configuration_aria.py (2 changes: 1 addition & 1 deletion)
@@ -91,8 +91,8 @@ def __init__(
self.image_size = image_size
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self._supports_sdpa = False
self.hidden_act = hidden_act

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
src/transformers/models/aria/modeling_aria.py (10 changes: 3 additions & 7 deletions)
@@ -1374,11 +1374,9 @@ def forward(self, input, tokens_per_expert):
if torch.cuda.is_available():
torch.cuda.set_device(input.device)
original_dtype = input.dtype
return experts_gemm(
input.to(torch.bfloat16),
self.weight.to(torch.bfloat16),
tokens_per_expert
).to(original_dtype)
return experts_gemm(input.to(torch.bfloat16), self.weight.to(torch.bfloat16), tokens_per_expert).to(
original_dtype
)


class AriaGroupedMLP(nn.Module):
@@ -2811,7 +2809,6 @@ def forward(
)



@dataclass
class AriaCausalLMOutputWithPast(ModelOutput):
"""
Expand Down Expand Up @@ -3076,4 +3073,3 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

src/transformers/models/aria/modular_aria.py (14 changes: 7 additions & 7 deletions)
@@ -36,8 +36,8 @@
LlamaForCausalLM,
LlamaMLP,
LlamaModel,
LlamaPreTrainedModel,
LlamaRMSNorm,
LlamaPreTrainedModel
)
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast
from ..siglip.configuration_siglip import SiglipVisionConfig
@@ -765,9 +765,10 @@ def __init__(
moe_z_loss_coeff: float = 1e-5,
moe_aux_loss_coeff: float = 1e-3,
moe_num_shared_experts: int = 2,
pad_token_id=2,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(pad_token_id=pad_token_id, **kwargs)
self.moe_intermediate_size = moe_intermediate_size
self.moe_num_experts = moe_num_experts
self.moe_topk = moe_topk
@@ -863,7 +864,6 @@ def _supports_sdpa(self):
return self.language_model._supports_sdpa

def _init_weights(self, module):
std = self.config.text_config.initializer_range
if hasattr(self.config, 'initializer_range'):
std = self.config.initializer_range
elif hasattr(self.config, 'text_config'):
@@ -880,6 +880,10 @@ def _init_weights(self, module):
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, AriaGroupedGEMM):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=std)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()


# adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/router.py#L96-L304
@@ -1124,10 +1128,6 @@ def _init_weights(self, module):
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, AriaGroupedGEMM):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=std)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()


class AriaTextModel(LlamaModel, AriaTextPreTrainedModel):
tests/models/aria/test_modeling_aria.py (6 changes: 5 additions & 1 deletion)
@@ -287,7 +287,11 @@ def test_sdpa_can_dispatch_on_flash(self):
pass

@unittest.skip(reason="")
def test_new_cache_format(self):
def test_new_cache_format_1(self):
pass

@unittest.skip(reason="")
def test_new_cache_format_0(self):
pass

@unittest.skip(reason="Feedforward chunking is not yet supported")
