torch fx compatible
gante committed Feb 2, 2024
1 parent f278527 commit ffbfc0f
Showing 2 changed files with 108 additions and 87 deletions.
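
The gist of the change below: the rotary sin/cos caches go back to being non-persistent buffers, built with an explicit device and dtype at construction time, instead of plain tensor attributes that the forward pass re-assigns and moves between devices on the fly; dropping that in-forward mutation is presumably what makes the module traceable, as the commit title suggests. As a refresher, here is a minimal standalone PyTorch sketch (not part of the diff; class names are illustrative) of the nn.Module behavior the new code relies on: .to() moves and casts registered buffers but ignores plain tensor attributes, and persistent=False keeps a buffer out of the state_dict.

import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered like the caches in the diff below: follows .to(), excluded from state_dict.
        self.register_buffer("cache", torch.zeros(4), persistent=False)

class WithPlainAttribute(nn.Module):
    def __init__(self):
        super().__init__()
        self.cache = torch.zeros(4)  # plain attribute: invisible to .to() / .cuda() / .half()

print(WithBuffer().to(torch.float16).cache.dtype)          # torch.float16
print(WithPlainAttribute().to(torch.float16).cache.dtype)  # torch.float32
print("cache" in WithBuffer().state_dict())                # False (non-persistent buffer)
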
30 changes: 13 additions & 17 deletions src/transformers/models/llama/modeling_llama.py
@@ -128,9 +128,9 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
self.base = base

# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=max_position_embeddings)
self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())

def _set_cos_sin_cache(self, seq_len):
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
@@ -140,17 +140,13 @@ def _set_cos_sin_cache(self, seq_len):
freqs = torch.outer(t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()
self.sin_cached = emb.sin()
self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len)

# Move to the target device if needed. This line prevents repeated device copies, if sin/cos haven't changed.
self.sin_cached = self.sin_cached.to(x.device)
self.cos_cached = self.cos_cached.to(x.device)
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

return (
self.cos_cached[:seq_len].to(x),
@@ -163,9 +159,9 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):

def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base)
super().__init__(dim, max_position_embeddings, base, device)

def _set_cos_sin_cache(self, seq_len):
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim)
@@ -176,18 +172,18 @@ def _set_cos_sin_cache(self, seq_len):
freqs = torch.outer(t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()
self.sin_cached = emb.sin()
self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device), persistent=False)


class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base)
super().__init__(dim, max_position_embeddings, base, device)

def _set_cos_sin_cache(self, seq_len):
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
if seq_len > self.max_position_embeddings:
base = self.base * (
@@ -201,8 +197,8 @@ def _set_cos_sin_cache(self, seq_len):
freqs = torch.outer(t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()
self.sin_cached = emb.sin()
self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device), persistent=False)


def rotate_half(x):
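Taken together, the modeling changes above mean the rotary module now behaves like any other module under casting: the sin/cos cache rides along with .to(), and the forward pass no longer re-assigns module attributes (it only rebuilds the cache when the requested sequence length outgrows it). A short usage sketch, assuming the constructor signature exercised in the tests below and running on CPU for simplicity:

import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

# Cache is built in float32 on CPU at construction time...
rope = LlamaRotaryEmbedding(dim=64, max_position_embeddings=2048, base=10000, device="cpu")
print(rope.sin_cached.dtype)          # torch.float32

# ...and, being a buffer now, it is cast together with the module.
rope = rope.to(dtype=torch.bfloat16)
print(rope.sin_cached.dtype)          # torch.bfloat16

# The forward pass leaves the cache alone and returns cos/sin matching the input.
x = torch.rand((1, 1), dtype=torch.bfloat16)
cos, sin = rope(x, seq_len=2048)
print(cos.dtype, sin.dtype)           # torch.bfloat16 torch.bfloat16
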
165 changes: 95 additions & 70 deletions tests/models/llama/test_modeling_llama.py
@@ -507,87 +507,112 @@ def test_eager_matches_sdpa_generate(self):
self.assertTrue(torch.allclose(res_eager, res_sdpa))

@require_torch_gpu
def test_rope_casting_invariant(self):
def test_rope_cast_strategy_invariant(self):
"""
Test exclusive to models with RoPE embeddings: tests that the RoPE embeddings are unaffected by the model
casting strategy (`.to()` or `torch_dtype`).
Test exclusive to models with RoPE embeddings: tests that the RoPE embeddings are invariant with respect to the
model cast strategy (`.to()` or `torch_dtype`).
"""
model_1 = LlamaForCausalLM.from_pretrained(
"HuggingFaceM4/tiny-random-LlamaForCausalLM",
device_map="auto",
torch_dtype=torch.bfloat16,
)
model_2 = LlamaForCausalLM.from_pretrained(
"HuggingFaceM4/tiny-random-LlamaForCausalLM",
device_map="auto",
).to(torch.bfloat16)
config = model_1.config

# Tests that a forward pass with inputs *smaller* than the initialized length works as expected.
input_ids = torch.randint(0, config.vocab_size, (1, 1)).to(model_1.device)
model_1_out = model_1(input_ids)
model_2_out = model_2(input_ids)
self.assertTrue(torch.allclose(model_1_out.logits, model_2_out.logits))
self.assertTrue(model_1.model.layers[0].self_attn.rotary_emb.sin_cached.dtype == torch.float32)
self.assertTrue(model_2.model.layers[0].self_attn.rotary_emb.sin_cached.dtype == torch.float32)

# Tests that a forward pass with inputs *larger* than the initialized length works correctly.
input_ids = torch.randint(0, config.vocab_size, (1, config.max_position_embeddings + 1)).to(model_1.device)
model_1_out = model_1(input_ids)
model_2_out = model_2(input_ids)
self.assertTrue(torch.allclose(model_1_out.logits, model_2_out.logits))
self.assertTrue(model_1.model.layers[0].self_attn.rotary_emb.sin_cached.dtype == torch.float32)
self.assertTrue(model_2.model.layers[0].self_attn.rotary_emb.sin_cached.dtype == torch.float32)
for dtype in (torch.float32, torch.float16, torch.bfloat16):
model_1 = LlamaForCausalLM.from_pretrained(
"HuggingFaceM4/tiny-random-LlamaForCausalLM",
device_map="auto",
torch_dtype=dtype,
)
model_2 = LlamaForCausalLM.from_pretrained(
"HuggingFaceM4/tiny-random-LlamaForCausalLM",
device_map="auto",
).to(dtype)
config = model_1.config

# Tests that a forward pass with inputs *smaller* than the initialized length works as expected.
input_ids = torch.randint(0, config.vocab_size, (1, 1)).to(model_1.device)
model_1_out = model_1(input_ids)
model_2_out = model_2(input_ids)
self.assertTrue(torch.allclose(model_1_out.logits, model_2_out.logits))
self.assertTrue(
torch.allclose(
model_1.model.layers[0].self_attn.rotary_emb.sin_cached,
model_2.model.layers[0].self_attn.rotary_emb.sin_cached,
)
)
self.assertTrue(
torch.allclose(
model_1.model.layers[0].self_attn.rotary_emb.cos_cached,
model_2.model.layers[0].self_attn.rotary_emb.cos_cached,
)
)

# Tests that a forward pass with inputs *larger* than the initialized length works correctly.
input_ids = torch.randint(0, config.vocab_size, (1, config.max_position_embeddings + 1)).to(model_1.device)
model_1_out = model_1(input_ids)
model_2_out = model_2(input_ids)
self.assertTrue(torch.allclose(model_1_out.logits, model_2_out.logits))
self.assertTrue(
torch.allclose(
model_1.model.layers[0].self_attn.rotary_emb.sin_cached,
model_2.model.layers[0].self_attn.rotary_emb.sin_cached,
)
)
self.assertTrue(
torch.allclose(
model_1.model.layers[0].self_attn.rotary_emb.cos_cached,
model_2.model.layers[0].self_attn.rotary_emb.cos_cached,
)
)
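
A side note on the block above: the same invariance can be reproduced outside the test harness. Condensed sketch using the tiny checkpoint from the test, with device_map="auto" dropped as a simplification so it also runs on CPU:

import torch
from transformers import LlamaForCausalLM

ckpt = "HuggingFaceM4/tiny-random-LlamaForCausalLM"
model_a = LlamaForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16)  # cast at load time
model_b = LlamaForCausalLM.from_pretrained(ckpt).to(torch.bfloat16)           # cast after loading

rope_a = model_a.model.layers[0].self_attn.rotary_emb
rope_b = model_b.model.layers[0].self_attn.rotary_emb
assert torch.allclose(rope_a.sin_cached, rope_b.sin_cached)  # same cache either way
assert torch.allclose(rope_a.cos_cached, rope_b.cos_cached)
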

@require_torch_gpu
def test_rope_device_invariant(self):
def test_rope_initialization_invariant(self):
"""
Test exclusive to models with RoPE embeddings: tests that the RoPE embeddings are unaffected by the device
placement.
Test exclusive to models with RoPE embeddings: tests that the RoPE embeddings are unaffected by the
initialization device and dtype.
"""

def compare_rope(test_dtype, dim, max_position_embeddings, base):
with torch.device("cuda"):
rope_gpu_init = LlamaRotaryEmbedding(
dim=dim, max_position_embeddings=max_position_embeddings, base=base, device="cuda"
)
rope_cpu_init = LlamaRotaryEmbedding(
# gpu init in test_dtype
torch.set_default_dtype(test_dtype)
rope_dtype_gpu_init = LlamaRotaryEmbedding(
dim=dim, max_position_embeddings=max_position_embeddings, base=base, device="cuda"
)
self.assertTrue(rope_dtype_gpu_init.sin_cached.device.type == "cuda")
self.assertTrue(rope_dtype_gpu_init.sin_cached.dtype == test_dtype)

# gpu init in fp32, cast to test_dtype
torch.set_default_dtype(torch.float32)
rope_fp32_gpu_init = LlamaRotaryEmbedding(
dim=dim, max_position_embeddings=max_position_embeddings, base=base, device="cuda"
)
self.assertTrue(rope_fp32_gpu_init.sin_cached.device.type == "cuda")
self.assertTrue(rope_fp32_gpu_init.sin_cached.dtype == torch.float32)
rope_fp32_gpu_init = rope_fp32_gpu_init.to(dtype=test_dtype)
self.assertTrue(rope_fp32_gpu_init.sin_cached.device.type == "cuda")
self.assertTrue(rope_fp32_gpu_init.sin_cached.dtype == test_dtype)

# cpu init in float32, cast to test_dtype and moved to the gpu
rope_fp32_cpu_init = LlamaRotaryEmbedding(
dim=dim, max_position_embeddings=max_position_embeddings, base=base, device="cpu"
)
# Despite the `with` statement and the `device` argument (that exists for backwards compatibility), the
# sin/cos tensors are set on CPU in both cases.
def check_cache(rope_a, rope_b, expected_device="cpu"):
self.assertTrue(expected_device in str(rope_a.sin_cached.device))
self.assertTrue(expected_device in str(rope_b.sin_cached.device))
self.assertTrue(rope_a.sin_cached.dtype == torch.float32)
self.assertTrue(rope_b.sin_cached.dtype == torch.float32)
self.assertTrue(torch.allclose(rope_a.sin_cached, rope_b.sin_cached))
self.assertTrue(torch.allclose(rope_a.cos_cached, rope_b.cos_cached))

check_cache(rope_gpu_init, rope_cpu_init)

# Moving the rope instance with `.to()` does not move the sin/cos tensors to the device nor changes its
# type
rope_cpu_init = rope_cpu_init.to(device="cuda", dtype=test_dtype)
rope_gpu_init = rope_gpu_init.to(device="cuda", dtype=test_dtype)
check_cache(rope_cpu_init, rope_gpu_init)

# However, running the forward pass will move the sin/cos cache to the input device, to prevent repeated
# copies to the target device. The type of the cache remains unchanged, but the output dtype will match
# the input dtype.
self.assertTrue(rope_fp32_cpu_init.sin_cached.device.type == "cpu")
self.assertTrue(rope_fp32_cpu_init.sin_cached.dtype == torch.float32)
rope_fp32_cpu_init = rope_fp32_cpu_init.to(dtype=test_dtype, device="cuda")
self.assertTrue(rope_fp32_cpu_init.sin_cached.device.type == "cuda")
self.assertTrue(rope_fp32_cpu_init.sin_cached.dtype == test_dtype)

# Sanity check 1: the sin/cos tensors should be the same for all initializations (after casting to
# test_dtype)
self.assertTrue(torch.allclose(rope_dtype_gpu_init.sin_cached, rope_fp32_gpu_init.sin_cached))
self.assertTrue(torch.allclose(rope_dtype_gpu_init.sin_cached, rope_fp32_cpu_init.sin_cached))
self.assertTrue(torch.allclose(rope_dtype_gpu_init.cos_cached, rope_fp32_gpu_init.cos_cached))
self.assertTrue(torch.allclose(rope_dtype_gpu_init.cos_cached, rope_fp32_cpu_init.cos_cached))

# Sanity check 2: the output of the forward pass is also the same
test_input = torch.rand((1, 1), device="cuda", dtype=test_dtype)
fwd_cos_cpu_init, fwd_sin_cpu_init = rope_cpu_init(test_input, seq_len=max_position_embeddings)
fwd_cos_gpu_init, fwd_sin_gpu_init = rope_gpu_init(test_input, seq_len=max_position_embeddings)
check_cache(rope_cpu_init, rope_gpu_init, expected_device="cuda")
self.assertTrue("cuda" in str(fwd_cos_cpu_init.device))
self.assertTrue("cuda" in str(fwd_cos_gpu_init.device))
self.assertTrue(fwd_cos_cpu_init.dtype == test_dtype)
self.assertTrue(fwd_cos_gpu_init.dtype == test_dtype)

max_sin_diff = (fwd_sin_cpu_init - fwd_sin_gpu_init).abs().max()
max_cos_diff = (fwd_cos_cpu_init - fwd_cos_gpu_init).abs().max()
max_diff = max(max_sin_diff, max_cos_diff)
self.assertEqual(max_diff, 0.0)
fwd_cos_dtype_gpu, fwd_sin_dtype_gpu = rope_dtype_gpu_init(test_input, seq_len=max_position_embeddings)
fwd_cos_fp32_gpu, fwd_sin_fp32_gpu = rope_fp32_gpu_init(test_input, seq_len=max_position_embeddings)
fwd_cos_fp32_cpu, fwd_sin_fp32_cpu = rope_fp32_cpu_init(test_input, seq_len=max_position_embeddings)
self.assertTrue(torch.allclose(fwd_cos_dtype_gpu, fwd_cos_fp32_gpu))
self.assertTrue(torch.allclose(fwd_cos_dtype_gpu, fwd_cos_fp32_cpu))
self.assertTrue(torch.allclose(fwd_sin_dtype_gpu, fwd_sin_fp32_gpu))
self.assertTrue(torch.allclose(fwd_sin_dtype_gpu, fwd_sin_fp32_cpu))

for test_dtype in (torch.float32, torch.float16, torch.bfloat16):
for dim in (64, 256, 1024):
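
The initialization test leans on the dtype plumbing added in __init__ (dtype=torch.get_default_dtype()): the cache math stays in float32 and only the final .to(...) applies the active default dtype, which is why the float32 and low-precision initializations agree once cast to the same type (sanity check 1 above). A short standalone illustration of that ordering, mirroring the arithmetic in the modeling diff:

import torch

torch.set_default_dtype(torch.bfloat16)   # what the test toggles per iteration
inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2, dtype=torch.int64).float() / 64))
t = torch.arange(2048, dtype=torch.int64).float()
freqs = torch.outer(t, inv_freq)          # still float32: the explicit .float() wins over the default
emb = torch.cat((freqs, freqs), dim=-1)
cache = emb.cos().to(dtype=torch.get_default_dtype())  # cast happens only at the very end
print(emb.dtype, cache.dtype)             # torch.float32 torch.bfloat16
torch.set_default_dtype(torch.float32)    # restore the default
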

0 comments on commit ffbfc0f
