Skip to content

Commit

Permalink
fix llama flax
Browse files — browse the repository at this point in the history
  • Loading branch information
zucchini-nlp committed Apr 29, 2024
1 parent 0f1997c commit 87befb7
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions tests/models/llama/test_modeling_flax_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input
)

past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
position_ids = jnp.broadcast_to(
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
position_ids = model.get_position_ids_from_attention_mask(
attention_mask_cache, input_ids.shape[0], input_ids.shape[-1] - 1
)

outputs_cache = model(
Expand All @@ -159,7 +159,6 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input
past_key_values=past_key_values,
position_ids=position_ids,
)
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model(
input_ids[:, -1:],
past_key_values=outputs_cache.past_key_values,
Expand Down

0 comments on commit 87befb7

Please sign in to comment.