From 1f9b8b9235b9bfa81d45d35f7e500fc31529b326 Mon Sep 17 00:00:00 2001 From: ksohrab3 Date: Wed, 13 Nov 2024 13:16:18 -0500 Subject: [PATCH 1/3] Support llama3/3.1/3.2 text models --- .../model_executor/layers/rotary_embedding.py | 60 ++++++++++++++++++- sarathi/utils/hf_utils.py | 11 +++- 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/sarathi/model_executor/layers/rotary_embedding.py b/sarathi/model_executor/layers/rotary_embedding.py index f6bacc8..a29af2f 100644 --- a/sarathi/model_executor/layers/rotary_embedding.py +++ b/sarathi/model_executor/layers/rotary_embedding.py @@ -179,6 +179,51 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: return cache +class Llama3RotaryEmbedding(RotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + low_freq_factor: float, + high_freq_factor: float, + orig_max_position: int, + ) -> None: + self.scaling_factor = scaling_factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor = high_freq_factor + self.orig_max_position = orig_max_position + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + low_freq_wavelen = self.orig_max_position / self.low_freq_factor + high_freq_wavelen = self.orig_max_position / self.high_freq_factor + + wave_len = 2 * math.pi / inv_freqs + if self.low_freq_factor != self.high_freq_factor: + smooth = (self.orig_max_position / wave_len - self.low_freq_factor + ) / (self.high_freq_factor - self.low_freq_factor) + else: + smooth = 0 + new_freqs = torch.where( + wave_len < high_freq_wavelen, + inv_freqs, + torch.where( + wave_len > low_freq_wavelen, + inv_freqs / self.scaling_factor, + (1 - smooth) * inv_freqs / self.scaling_factor + + smooth * inv_freqs, + ), + ) + return new_freqs + + # Inverse dim formula to find dim based on number of rotations def _yarn_find_correction_dim( num_rotations: int, @@ -311,9 +356,20 @@ def get_rope( head_size, rotary_dim, max_position, base, is_neox_style ) else: - scaling_type = rope_scaling["type"] + scaling_type = rope_scaling["type"] if "type" in rope_scaling else rope_scaling["rope_type"] scaling_factor = rope_scaling["factor"] - if scaling_type == "linear": + if scaling_type == "llama3": + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, + scaling_factor, low_freq_factor, + high_freq_factor, + original_max_position) + elif scaling_type == "linear": rotary_emb = LinearScalingRotaryEmbedding( head_size, rotary_dim, max_position, base, is_neox_style, scaling_factor ) diff --git a/sarathi/utils/hf_utils.py b/sarathi/utils/hf_utils.py index a0d644a..4b1f175 100644 --- a/sarathi/utils/hf_utils.py +++ b/sarathi/utils/hf_utils.py @@ -88,6 +88,15 @@ def get_and_verify_max_len( derived_max_model_len = min(derived_max_model_len, max_len_key) rope_scaling = getattr(hf_config, "rope_scaling", None) + if rope_scaling is not None: + if "type" in rope_scaling: + rope_type = rope_scaling["type"] + elif "rope_type" in rope_scaling: + rope_type = rope_scaling["rope_type"] + else: + raise ValueError( + "rope_scaling must have a 'type' or 'rope_type' key.") + if rope_scaling is not None: if derived_max_model_len == float("inf"): raise ValueError( @@ -97,7 +106,7 @@ def get_and_verify_max_len( ) assert "factor" in rope_scaling scaling_factor = rope_scaling["factor"] - if rope_scaling["type"] == "yarn": + if rope_type == "yarn": derived_max_model_len = rope_scaling["original_max_position_embeddings"] derived_max_model_len *= scaling_factor From 51eba9ea21e38340006dc6b6874f7b1422b1d0f7 Mon Sep 17 00:00:00 2001 From: ksohrab3 Date: Wed, 13 Nov 2024 16:21:01 -0500 Subject: [PATCH 2/3] Bug fix cast max len to int --- sarathi/utils/hf_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sarathi/utils/hf_utils.py b/sarathi/utils/hf_utils.py index 4b1f175..f2680e5 100644 --- a/sarathi/utils/hf_utils.py +++ b/sarathi/utils/hf_utils.py @@ -123,4 +123,4 @@ def get_and_verify_max_len( rope_scaling = {"type": "linear", "factor": scaling_factor} hf_config.rope_scaling = rope_scaling - return max_model_len + return int(max_model_len) \ No newline at end of file From 4ea3cf7d68745876ce4a28c441ecade81b7ece8d Mon Sep 17 00:00:00 2001 From: ksohrab3 Date: Wed, 13 Nov 2024 18:46:35 -0500 Subject: [PATCH 3/3] Make format --- .../model_executor/layers/rotary_embedding.py | 39 ++++++++++++------- sarathi/utils/hf_utils.py | 5 +-- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/sarathi/model_executor/layers/rotary_embedding.py b/sarathi/model_executor/layers/rotary_embedding.py index a29af2f..f206349 100644 --- a/sarathi/model_executor/layers/rotary_embedding.py +++ b/sarathi/model_executor/layers/rotary_embedding.py @@ -197,8 +197,9 @@ def __init__( self.low_freq_factor = low_freq_factor self.high_freq_factor = high_freq_factor self.orig_max_position = orig_max_position - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style + ) def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: inv_freqs = super()._compute_inv_freq(base) @@ -207,8 +208,9 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: wave_len = 2 * math.pi / inv_freqs if self.low_freq_factor != self.high_freq_factor: - smooth = (self.orig_max_position / wave_len - self.low_freq_factor - ) / (self.high_freq_factor - self.low_freq_factor) + smooth = (self.orig_max_position / wave_len - self.low_freq_factor) / ( + self.high_freq_factor - self.low_freq_factor + ) else: smooth = 0 new_freqs = torch.where( @@ -217,8 +219,7 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: torch.where( wave_len > low_freq_wavelen, inv_freqs / self.scaling_factor, - (1 - smooth) * inv_freqs / self.scaling_factor + - smooth * inv_freqs, + (1 - smooth) * inv_freqs / self.scaling_factor + smooth * inv_freqs, ), ) return new_freqs @@ -356,19 +357,27 @@ def get_rope( head_size, rotary_dim, max_position, base, is_neox_style ) else: - scaling_type = rope_scaling["type"] if "type" in rope_scaling else rope_scaling["rope_type"] + scaling_type = ( + rope_scaling["type"] + if "type" in rope_scaling + else rope_scaling["rope_type"] + ) scaling_factor = rope_scaling["factor"] if scaling_type == "llama3": low_freq_factor = rope_scaling["low_freq_factor"] high_freq_factor = rope_scaling["high_freq_factor"] - original_max_position = rope_scaling[ - "original_max_position_embeddings"] - rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim, - max_position, base, - is_neox_style, - scaling_factor, low_freq_factor, - high_freq_factor, - original_max_position) + original_max_position = rope_scaling["original_max_position_embeddings"] + rotary_emb = Llama3RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + low_freq_factor, + high_freq_factor, + original_max_position, + ) elif scaling_type == "linear": rotary_emb = LinearScalingRotaryEmbedding( head_size, rotary_dim, max_position, base, is_neox_style, scaling_factor diff --git a/sarathi/utils/hf_utils.py b/sarathi/utils/hf_utils.py index f2680e5..7505025 100644 --- a/sarathi/utils/hf_utils.py +++ b/sarathi/utils/hf_utils.py @@ -94,8 +94,7 @@ def get_and_verify_max_len( elif "rope_type" in rope_scaling: rope_type = rope_scaling["rope_type"] else: - raise ValueError( - "rope_scaling must have a 'type' or 'rope_type' key.") + raise ValueError("rope_scaling must have a 'type' or 'rope_type' key.") if rope_scaling is not None: if derived_max_model_len == float("inf"): @@ -123,4 +122,4 @@ def get_and_verify_max_len( rope_scaling = {"type": "linear", "factor": scaling_factor} hf_config.rope_scaling = rope_scaling - return int(max_model_len) \ No newline at end of file + return int(max_model_len)