From ffbfc0f9e58855deca1c54096f0c66b90cf824f8 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Fri, 2 Feb 2024 17:13:15 +0000
Subject: [PATCH] torch fx compatible

---
 .../models/llama/modeling_llama.py        |  30 ++--
 tests/models/llama/test_modeling_llama.py | 165 ++++++++++--------
 2 files changed, 108 insertions(+), 87 deletions(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 6a2c6020d75a8f..28cb7b0c2b9d61 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -128,9 +128,9 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         self.base = base

         # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(seq_len=max_position_embeddings)
+        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())

-    def _set_cos_sin_cache(self, seq_len):
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len
         inv_freq = 1.0 / (
             self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
@@ -140,17 +140,13 @@ def _set_cos_sin_cache(self, seq_len):
         freqs = torch.outer(t, inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.cos_cached = emb.cos()
-        self.sin_cached = emb.sin()
+        self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len)
-
-        # Move to the target device if needed. This line prevents repeated device copies, if sin/cos haven't changed.
-        self.sin_cached = self.sin_cached.to(x.device)
-        self.cos_cached = self.cos_cached.to(x.device)
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

         return (
             self.cos_cached[:seq_len].to(x),
@@ -163,9 +159,9 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):

     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         self.scaling_factor = scaling_factor
-        super().__init__(dim, max_position_embeddings, base)
+        super().__init__(dim, max_position_embeddings, base, device)

-    def _set_cos_sin_cache(self, seq_len):
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len
         inv_freq = 1.0 / (
             self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
@@ -176,8 +172,8 @@ def _set_cos_sin_cache(self, seq_len):
         freqs = torch.outer(t, inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.cos_cached = emb.cos()
-        self.sin_cached = emb.sin()
+        self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device), persistent=False)


 class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
@@ -185,9 +181,9 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):

     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         self.scaling_factor = scaling_factor
-        super().__init__(dim, max_position_embeddings, base)
+        super().__init__(dim, max_position_embeddings, base, device)

-    def _set_cos_sin_cache(self, seq_len):
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len
         if seq_len > self.max_position_embeddings:
             base = self.base * (
@@ -201,8 +197,8 @@ def _set_cos_sin_cache(self, seq_len):
         freqs = torch.outer(t, inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.cos_cached = emb.cos()
-        self.sin_cached = emb.sin()
+        self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device), persistent=False)


 def rotate_half(x):
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index 852e4de34d4719..f81d5ae6de0b1a 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -507,87 +507,112 @@ def test_eager_matches_sdpa_generate(self):
         self.assertTrue(torch.allclose(res_eager, res_sdpa))

     @require_torch_gpu
-    def test_rope_casting_invariant(self):
+    def test_rope_cast_strategy_invariant(self):
         """
-        Test exclusive to models with RoPE embeddings: tests that the RoPE embeddings are unnafected by the model
-        casting strategy (`.to()` or `torch_dtype`).
+        Test exclusive to models with RoPE embeddings: tests that the RoPE embeddings are invariant with respect to the
+        model cast strategy (`.to()` or `torch_dtype`).
""" - model_1 = LlamaForCausalLM.from_pretrained( - "HuggingFaceM4/tiny-random-LlamaForCausalLM", - device_map="auto", - torch_dtype=torch.bfloat16, - ) - model_2 = LlamaForCausalLM.from_pretrained( - "HuggingFaceM4/tiny-random-LlamaForCausalLM", - device_map="auto", - ).to(torch.bfloat16) - config = model_1.config - - # Tests that a forward pass with inputs *smaller* than the initialized length work as expected. - input_ids = torch.randint(0, config.vocab_size, (1, 1)).to(model_1.device) - model_1_out = model_1(input_ids) - model_2_out = model_2(input_ids) - self.assertTrue(torch.allclose(model_1_out.logits, model_2_out.logits)) - self.assertTrue(model_1.model.layers[0].self_attn.rotary_emb.sin_cached.dtype == torch.float32) - self.assertTrue(model_2.model.layers[0].self_attn.rotary_emb.sin_cached.dtype == torch.float32) - - # Tests that a forward pass with inputs *larger* than the initialized length work correctly. - input_ids = torch.randint(0, config.vocab_size, (1, config.max_position_embeddings + 1)).to(model_1.device) - model_1_out = model_1(input_ids) - model_2_out = model_2(input_ids) - self.assertTrue(torch.allclose(model_1_out.logits, model_2_out.logits)) - self.assertTrue(model_1.model.layers[0].self_attn.rotary_emb.sin_cached.dtype == torch.float32) - self.assertTrue(model_2.model.layers[0].self_attn.rotary_emb.sin_cached.dtype == torch.float32) + for dtype in (torch.float32, torch.float16, torch.bfloat16): + model_1 = LlamaForCausalLM.from_pretrained( + "HuggingFaceM4/tiny-random-LlamaForCausalLM", + device_map="auto", + torch_dtype=dtype, + ) + model_2 = LlamaForCausalLM.from_pretrained( + "HuggingFaceM4/tiny-random-LlamaForCausalLM", + device_map="auto", + ).to(dtype) + config = model_1.config + + # Tests that a forward pass with inputs *smaller* than the initialized length work as expected. + input_ids = torch.randint(0, config.vocab_size, (1, 1)).to(model_1.device) + model_1_out = model_1(input_ids) + model_2_out = model_2(input_ids) + self.assertTrue(torch.allclose(model_1_out.logits, model_2_out.logits)) + self.assertTrue( + torch.allclose( + model_1.model.layers[0].self_attn.rotary_emb.sin_cached, + model_2.model.layers[0].self_attn.rotary_emb.sin_cached, + ) + ) + self.assertTrue( + torch.allclose( + model_1.model.layers[0].self_attn.rotary_emb.cos_cached, + model_2.model.layers[0].self_attn.rotary_emb.cos_cached, + ) + ) + + # Tests that a forward pass with inputs *larger* than the initialized length work correctly. + input_ids = torch.randint(0, config.vocab_size, (1, config.max_position_embeddings + 1)).to(model_1.device) + model_1_out = model_1(input_ids) + model_2_out = model_2(input_ids) + self.assertTrue(torch.allclose(model_1_out.logits, model_2_out.logits)) + self.assertTrue( + torch.allclose( + model_1.model.layers[0].self_attn.rotary_emb.sin_cached, + model_2.model.layers[0].self_attn.rotary_emb.sin_cached, + ) + ) + self.assertTrue( + torch.allclose( + model_1.model.layers[0].self_attn.rotary_emb.cos_cached, + model_2.model.layers[0].self_attn.rotary_emb.cos_cached, + ) + ) @require_torch_gpu - def test_rope_device_invariant(self): + def test_rope_initialization_invariant(self): """ - Test exclusive to models with RoPE embeddings: tests that the RoPE embeddings are unnafected by the device - placement. + Test exclusive to models with RoPE embeddings: tests that the RoPE embeddings are unnafected by the + initialization device and dtype. 
""" def compare_rope(test_dtype, dim, max_position_embeddings, base): - with torch.device("cuda"): - rope_gpu_init = LlamaRotaryEmbedding( - dim=dim, max_position_embeddings=max_position_embeddings, base=base, device="cuda" - ) - rope_cpu_init = LlamaRotaryEmbedding( + # gpu init in test_dtype + torch.set_default_dtype(test_dtype) + rope_dtype_gpu_init = LlamaRotaryEmbedding( + dim=dim, max_position_embeddings=max_position_embeddings, base=base, device="cuda" + ) + self.assertTrue(rope_dtype_gpu_init.sin_cached.device.type == "cuda") + self.assertTrue(rope_dtype_gpu_init.sin_cached.dtype == test_dtype) + + # gpu init in fp32, casted to test_dtype + torch.set_default_dtype(torch.float32) + rope_fp32_gpu_init = LlamaRotaryEmbedding( + dim=dim, max_position_embeddings=max_position_embeddings, base=base, device="cuda" + ) + self.assertTrue(rope_fp32_gpu_init.sin_cached.device.type == "cuda") + self.assertTrue(rope_fp32_gpu_init.sin_cached.dtype == torch.float32) + rope_fp32_gpu_init = rope_fp32_gpu_init.to(dtype=test_dtype) + self.assertTrue(rope_fp32_gpu_init.sin_cached.device.type == "cuda") + self.assertTrue(rope_fp32_gpu_init.sin_cached.dtype == test_dtype) + + # cpu init in float32, casted to test_dtype and moved to the gpu + rope_fp32_cpu_init = LlamaRotaryEmbedding( dim=dim, max_position_embeddings=max_position_embeddings, base=base, device="cpu" ) - # Despite the `with` statement and the `device` argument (that exists for backwards compatibility), the - # sin/cos tensors are set on CPU in both cases. - def check_cache(rope_a, rope_b, expected_device="cpu"): - self.assertTrue(expected_device in str(rope_a.sin_cached.device)) - self.assertTrue(expected_device in str(rope_b.sin_cached.device)) - self.assertTrue(rope_a.sin_cached.dtype == torch.float32) - self.assertTrue(rope_b.sin_cached.dtype == torch.float32) - self.assertTrue(torch.allclose(rope_a.sin_cached, rope_b.sin_cached)) - self.assertTrue(torch.allclose(rope_a.cos_cached, rope_b.cos_cached)) - - check_cache(rope_gpu_init, rope_cpu_init) - - # Moving the rope instance with `.to()` does not move the sin/cos tensors to the device nor changes its - # type - rope_cpu_init = rope_cpu_init.to(device="cuda", dtype=test_dtype) - rope_gpu_init = rope_gpu_init.to(device="cuda", dtype=test_dtype) - check_cache(rope_cpu_init, rope_gpu_init) - - # However, running the forward pass will move the sin/cos cache to the input device, to prevent repeated - # copies to the target device. The type of the cache remains unchanged, but the output dtype will match - # the input dtype. 
+            self.assertTrue(rope_fp32_cpu_init.sin_cached.device.type == "cpu")
+            self.assertTrue(rope_fp32_cpu_init.sin_cached.dtype == torch.float32)
+            rope_fp32_cpu_init = rope_fp32_cpu_init.to(dtype=test_dtype, device="cuda")
+            self.assertTrue(rope_fp32_cpu_init.sin_cached.device.type == "cuda")
+            self.assertTrue(rope_fp32_cpu_init.sin_cached.dtype == test_dtype)
+
+            # Sanity check 1: the sin/cos tensors should be the same for all initializations (after casting to
+            # test_dtype)
+            self.assertTrue(torch.allclose(rope_dtype_gpu_init.sin_cached, rope_fp32_gpu_init.sin_cached))
+            self.assertTrue(torch.allclose(rope_dtype_gpu_init.sin_cached, rope_fp32_cpu_init.sin_cached))
+            self.assertTrue(torch.allclose(rope_dtype_gpu_init.cos_cached, rope_fp32_gpu_init.cos_cached))
+            self.assertTrue(torch.allclose(rope_dtype_gpu_init.cos_cached, rope_fp32_cpu_init.cos_cached))
+
+            # Sanity check 2: the output of the forward pass is also the same
             test_input = torch.rand((1, 1), device="cuda", dtype=test_dtype)
-            fwd_cos_cpu_init, fwd_sin_cpu_init = rope_cpu_init(test_input, seq_len=max_position_embeddings)
-            fwd_cos_gpu_init, fwd_sin_gpu_init = rope_gpu_init(test_input, seq_len=max_position_embeddings)
-            check_cache(rope_cpu_init, rope_gpu_init, expected_device="cuda")
-            self.assertTrue("cuda" in str(fwd_cos_cpu_init.device))
-            self.assertTrue("cuda" in str(fwd_cos_gpu_init.device))
-            self.assertTrue(fwd_cos_cpu_init.dtype == test_dtype)
-            self.assertTrue(fwd_cos_gpu_init.dtype == test_dtype)
-
-            max_sin_diff = (fwd_sin_cpu_init - fwd_sin_gpu_init).abs().max()
-            max_cos_diff = (fwd_cos_cpu_init - fwd_cos_gpu_init).abs().max()
-            max_diff = max(max_sin_diff, max_cos_diff)
-            self.assertEqual(max_diff, 0.0)
+            fwd_cos_dtype_gpu, fwd_sin_dtype_gpu = rope_dtype_gpu_init(test_input, seq_len=max_position_embeddings)
+            fwd_cos_fp32_gpu, fwd_sin_fp32_gpu = rope_fp32_gpu_init(test_input, seq_len=max_position_embeddings)
+            fwd_cos_fp32_cpu, fwd_sin_fp32_cpu = rope_fp32_cpu_init(test_input, seq_len=max_position_embeddings)
+            self.assertTrue(torch.allclose(fwd_cos_dtype_gpu, fwd_cos_fp32_gpu))
+            self.assertTrue(torch.allclose(fwd_cos_dtype_gpu, fwd_cos_fp32_cpu))
+            self.assertTrue(torch.allclose(fwd_sin_dtype_gpu, fwd_sin_fp32_gpu))
+            self.assertTrue(torch.allclose(fwd_sin_dtype_gpu, fwd_sin_fp32_cpu))

         for test_dtype in (torch.float32, torch.float16, torch.bfloat16):
             for dim in (64, 256, 1024):