From 19a168c7f1776168248a861e733c88c1da7c9d06 Mon Sep 17 00:00:00 2001
From: Koan-Sin Tan
Date: Sat, 18 May 2024 13:16:58 +0800
Subject: [PATCH] fix the tiny llama conversion issue (#7)

* fix the tiny llama conversion issue

* fix formatting.

---------

Co-authored-by: Advait Jain
---
 ai_edge_torch/generative/utilities/loader.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/ai_edge_torch/generative/utilities/loader.py b/ai_edge_torch/generative/utilities/loader.py
index 873e935d..020f2489 100644
--- a/ai_edge_torch/generative/utilities/loader.py
+++ b/ai_edge_torch/generative/utilities/loader.py
@@ -228,14 +228,14 @@ def _map_attention(
     q_name = self._names.attn_query_proj.format(idx)
     k_name = self._names.attn_key_proj.format(idx)
     v_name = self._names.attn_value_proj.format(idx)
-    converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
+    converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
         config,
         state.pop(f"{q_name}.weight"),
         state.pop(f"{k_name}.weight"),
         state.pop(f"{v_name}.weight"),
     )
     if config.attn_config.qkv_use_bias:
-      converted_state[f"{prefix}.atten_func.attn.bias"] = self._fuse_qkv(
+      converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
           config,
           state.pop(f"{q_name}.bias"),
           state.pop(f"{k_name}.bias"),
@@ -243,9 +243,13 @@ def _map_attention(
     )
     o_name = self._names.attn_output_proj.format(idx)
-    converted_state[f"{prefix}.atten_func.proj.weight"] = state.pop(f"{o_name}.weight")
+    converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
+        f"{o_name}.weight"
+    )
     if config.attn_config.output_proj_use_bias:
-      converted_state[f"{prefix}.atten_func.proj.bias"] = state.pop(f"{o_name}.bias")
+      converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
+          f"{o_name}.bias"
+      )
 
   def _map_norm(
       self,
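
For context, the keys being renamed feed a fused QKV linear layer and a separate output projection in the attention block. The sketch below is illustrative only: it assumes the fused layer expects the Q, K, and V weights stacked along the output dimension, whereas the actual _fuse_qkv in loader.py also takes the model config and may lay out heads differently (e.g. for grouped-query attention); fuse_qkv_sketch and its usage lines are hypothetical names, not the library's API.

import torch

def fuse_qkv_sketch(
    q_weight: torch.Tensor,  # shape (q_out, hidden)
    k_weight: torch.Tensor,  # shape (k_out, hidden)
    v_weight: torch.Tensor,  # shape (v_out, hidden)
) -> torch.Tensor:
  """Illustrative fusion: stack the three projection weights into one matrix."""
  return torch.cat([q_weight, k_weight, v_weight], dim=0)

# Hypothetical usage mirroring the renamed checkpoint keys from the patch:
# converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = fuse_qkv_sketch(q_w, k_w, v_w)
# converted_state[f"{prefix}.atten_func.output_projection.weight"] = o_w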