Skip to content

Commit

Permalink
fix a few test cases
Browse files · Browse the repository at this point in the history
  • Loading branch information
gante committed May 14, 2024
1 parent e210081 commit ca05970
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 7 deletions.
19 changes: 12 additions & 7 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,13 +1344,18 @@ def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCa
cache_dtype = self.config._pre_quantization_dtype
else:
cache_dtype = self.dtype
self._static_cache = StaticCache(
config=self.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=self.device,
dtype=cache_dtype,
)
try:
self._static_cache = StaticCache(
config=self.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=self.device,
dtype=cache_dtype,
)
except AttributeError:
raise ValueError(
f"This model's class ({self.__class__.__name__}) does not support static cache for generation."
)
else:
self._static_cache.reset() # reset the cache for a new generation
return self._static_cache
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/dbrx/configuration_dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def __init__(
self.use_cache = use_cache
self.initializer_range = initializer_range
self.output_router_logits = output_router_logits
self.num_key_value_heads = self.attn_config.kv_n_heads

tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
if tie_word_embeddings:
Expand Down
1 change: 1 addition & 0 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,6 +1655,7 @@ def test_new_cache_format(self, num_beams, do_sample):

@require_torch_gpu
@slow
@is_flaky() # compilation may result in equivalent (!= same) FP ops, causing the argmax in generate to be flaky
def test_generate_compile_fullgraph(self):
"""Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results"""
for model_class in self.all_generative_model_classes:
Expand Down
4 changes: 4 additions & 0 deletions tests/models/dbrx/test_modeling_dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,10 @@ def test_model_from_pretrained(self):
def test_tied_weights_keys(self):
    # Deliberate no-op override: the shared tied-weights-keys check is disabled
    # for this model (presumably via a skip decorator above this diff view —
    # TODO confirm in the full test file).
    pass

# Added by this commit: opt DBRX out of the new shared compile test.
@unittest.skip("Dbrx does not support `torch.compile` with `fullgraph=True`.")
def test_generate_compile_fullgraph(self):
    # Body never executes: `unittest.skip` above short-circuits the test at
    # collection time, so only a placeholder statement is needed here.
    pass


@require_torch
class DbrxModelIntegrationTest(unittest.TestCase):
Expand Down
4 changes: 4 additions & 0 deletions tests/models/jamba/test_modeling_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,10 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
def test_new_cache_format(self, num_beams, do_sample):
    # Deliberate no-op override disabling the shared cache-format test for this
    # model (likely paired with a skip/parameterized decorator above this diff
    # view — TODO confirm in the full test file).
    pass

# Added by this commit: opt Jamba out of the new shared compile test.
@unittest.skip("Jamba has its own special cache type")
def test_generate_compile_fullgraph(self):
    # Body never executes: `unittest.skip` above short-circuits the test at
    # collection time, so only a placeholder statement is needed here.
    pass


@require_torch
class JambaModelIntegrationTest(unittest.TestCase):
Expand Down

0 comments on commit ca05970

Please sign in to comment.