Skip to content

Commit

Permalink
Test loading generation config with safetensor weights (#31550)
Browse files Browse the repository at this point in the history
fix test
  • Loading branch information
gante authored Jul 9, 2024
1 parent cffa2b9 commit 4c2538b
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,20 +1424,15 @@ def test_pretrained_low_mem_new_config(self):
self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)

def test_generation_config_is_loaded_with_model(self):
# Note: `joaogante/tiny-random-gpt2-with-generation-config` has a `generation_config.json` containing a dummy
# `transformers_version` field set to `foo`. If loading the file fails, this test also fails.
# Note: `TinyLlama/TinyLlama-1.1B-Chat-v1.0` has a `generation_config.json` containing `max_length: 2048`

# 1. Load without further parameters
model = AutoModelForCausalLM.from_pretrained(
"joaogante/tiny-random-gpt2-with-generation-config", use_safetensors=False
)
self.assertEqual(model.generation_config.transformers_version, "foo")
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
self.assertEqual(model.generation_config.max_length, 2048)

# 2. Load with `device_map`
model = AutoModelForCausalLM.from_pretrained(
"joaogante/tiny-random-gpt2-with-generation-config", device_map="auto", use_safetensors=False
)
self.assertEqual(model.generation_config.transformers_version, "foo")
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto")
self.assertEqual(model.generation_config.max_length, 2048)

@require_safetensors
def test_safetensors_torch_from_torch(self):
Expand Down

0 comments on commit 4c2538b

Please sign in to comment.