Skip to content

Commit

Permalink
fix the tiny llama conversion issue (#7)
Browse files Browse the repository at this point in the history
* fix the tiny llama conversion issue

* fix formatting.

---------

Co-authored-by: Advait Jain <[email protected]>
  • Loading branch information
freedomtan and advaitjain authored May 18, 2024
1 parent 2d4e18e commit 19a168c
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions ai_edge_torch/generative/utilities/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,24 +228,28 @@ def _map_attention(
q_name = self._names.attn_query_proj.format(idx)
k_name = self._names.attn_key_proj.format(idx)
v_name = self._names.attn_value_proj.format(idx)
converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
config,
state.pop(f"{q_name}.weight"),
state.pop(f"{k_name}.weight"),
state.pop(f"{v_name}.weight"),
)
if config.attn_config.qkv_use_bias:
converted_state[f"{prefix}.atten_func.attn.bias"] = self._fuse_qkv(
converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
config,
state.pop(f"{q_name}.bias"),
state.pop(f"{k_name}.bias"),
state.pop(f"{v_name}.bias"),
)

o_name = self._names.attn_output_proj.format(idx)
converted_state[f"{prefix}.atten_func.proj.weight"] = state.pop(f"{o_name}.weight")
converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
f"{o_name}.weight"
)
if config.attn_config.output_proj_use_bias:
converted_state[f"{prefix}.atten_func.proj.bias"] = state.pop(f"{o_name}.bias")
converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
f"{o_name}.bias"
)

def _map_norm(
self,
Expand Down

0 comments on commit 19a168c

Please sign in to comment.