diff --git a/ai_edge_torch/generative/utilities/loader.py b/ai_edge_torch/generative/utilities/loader.py
index 873e935d..020f2489 100644
--- a/ai_edge_torch/generative/utilities/loader.py
+++ b/ai_edge_torch/generative/utilities/loader.py
@@ -228,14 +228,14 @@ def _map_attention(
     q_name = self._names.attn_query_proj.format(idx)
     k_name = self._names.attn_key_proj.format(idx)
     v_name = self._names.attn_value_proj.format(idx)
-    converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
+    converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
         config,
         state.pop(f"{q_name}.weight"),
         state.pop(f"{k_name}.weight"),
         state.pop(f"{v_name}.weight"),
     )
     if config.attn_config.qkv_use_bias:
-      converted_state[f"{prefix}.atten_func.attn.bias"] = self._fuse_qkv(
+      converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
           config,
           state.pop(f"{q_name}.bias"),
           state.pop(f"{k_name}.bias"),
@@ -243,9 +243,13 @@ def _map_attention(
       )
 
     o_name = self._names.attn_output_proj.format(idx)
-    converted_state[f"{prefix}.atten_func.proj.weight"] = state.pop(f"{o_name}.weight")
+    converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
+        f"{o_name}.weight"
+    )
     if config.attn_config.output_proj_use_bias:
-      converted_state[f"{prefix}.atten_func.proj.bias"] = state.pop(f"{o_name}.bias")
+      converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
+          f"{o_name}.bias"
+      )
 
   def _map_norm(
       self,
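
For context on the renamed `qkv_projection` key: in the simple case of multi-head attention where Q, K, and V projections share the same output size, fusing the three checkpoint tensors amounts to concatenating their weights along the output dimension, so a single linear layer produces Q, K, and V in one matmul. Below is a minimal sketch of that idea; the `fuse_qkv_sketch` helper and the `transformer_blocks.0` prefix are illustrative only, and the real `_fuse_qkv` also takes `config` (e.g. it may need a different layout for grouped-query attention), so treat this as a conceptual sketch, not the library's implementation.

```python
# Minimal sketch (NOT the library's actual _fuse_qkv): fuse separate Q/K/V
# projection weights into one tensor matching a single fused linear layer.
import torch


def fuse_qkv_sketch(
    q_w: torch.Tensor,  # (n_embd, n_embd) query projection weight
    k_w: torch.Tensor,  # (n_embd, n_embd) key projection weight
    v_w: torch.Tensor,  # (n_embd, n_embd) value projection weight
) -> torch.Tensor:
  # Concatenated along dim 0, the fused weight has shape (3 * n_embd, n_embd),
  # matching an nn.Linear(n_embd, 3 * n_embd) "qkv_projection" layer.
  return torch.cat([q_w, k_w, v_w], dim=0)


# Usage: map separate checkpoint tensors onto the fused layer's state-dict key
# (the "transformer_blocks.0" prefix here is a hypothetical example).
n_embd = 8
q, k, v = (torch.randn(n_embd, n_embd) for _ in range(3))
converted_state = {
    "transformer_blocks.0.atten_func.qkv_projection.weight": fuse_qkv_sketch(
        q, k, v
    )
}
assert converted_state[
    "transformer_blocks.0.atten_func.qkv_projection.weight"
].shape == (3 * n_embd, n_embd)
```

The rename itself changes no behavior: `attn` becomes `qkv_projection` and `proj` becomes `output_projection`, so the converted state-dict keys describe what each tensor is rather than relying on abbreviations.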