Yarn extensions
danieldk committed Jun 27, 2024
1 parent 09904de commit bfa3a74
Showing 1 changed file with 18 additions and 6 deletions.
server/text_generation_server/layers/rotary.py (24 changes: 18 additions & 6 deletions)
@@ -97,6 +97,8 @@ def static(cls, config, dim, base, device):
                 )
             elif rope_scaling["type"] == "yarn":
                 scaling_factor = rope_scaling["factor"]
+                mscale = rope_scaling.get("mscale", 1.0)
+                mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
                 return YarnPositionRotaryEmbedding(
                     dim=2 * inv_freq.shape[0],
                     max_position_embeddings=rope_scaling[
@@ -109,6 +111,8 @@ def static(cls, config, dim, base, device):
                     attn_factor=1,
                     beta_fast=32,
                     beta_slow=1,
+                    mscale=mscale,
+                    mscale_all_dim=mscale_all_dim,
                 )
             elif rope_scaling["type"] == "su":
                 short_factor = torch.tensor(
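
For reference, a minimal sketch of the kind of rope_scaling entry the two new .get(...) calls consume; the key names and defaults come from the diff above, while the numeric values are purely illustrative:

    # Hypothetical config values for illustration only; not taken from the commit.
    rope_scaling = {
        "type": "yarn",
        "factor": 40.0,
        "original_max_position_embeddings": 4096,
        "mscale": 0.707,
        "mscale_all_dim": 0.707,
    }

    # The patch reads the two new keys with backwards-compatible defaults,
    # so configs that omit them keep the previous behaviour.
    mscale = rope_scaling.get("mscale", 1.0)
    mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)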
@@ -181,6 +185,8 @@ def load(cls, config, prefix, weights):
                     scaling_factor=scaling_factor,
                 )
             elif rope_scaling["type"] == "yarn":
+                mscale = rope_scaling.get("mscale", 1.0)
+                mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
                 return YarnPositionRotaryEmbedding(
                     dim=2 * inv_freq.shape[0],
                     max_position_embeddings=rope_scaling[
@@ -193,6 +199,8 @@ def load(cls, config, prefix, weights):
                     attn_factor=1,
                     beta_fast=32,
                     beta_slow=1,
+                    mscale=mscale,
+                    mscale_all_dim=mscale_all_dim,
                 )
             else:
                 raise NotImplementedError(
@@ -346,10 +354,10 @@ def linear_ramp_mask(min, max, dim):
     return ramp_func


-def get_mscale(scale=1):
+def get_mscale(scale=1, mscale: float = 1.0):
     if scale <= 1:
         return 1.0
-    return 0.1 * math.log(scale) + 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0


 class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
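
A self-contained sketch of how the extended helper combines with the __init__ change further down; the scaling values are made up for the example:

    import math

    def get_mscale(scale=1, mscale: float = 1.0):
        # Same shape as the patched helper: mscale now weights the log term.
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    # Illustrative inputs; a real model config supplies these.
    scaling_factor, attn_factor = 40.0, 1
    mscale, mscale_all_dim = 0.707, 0.707

    effective = (
        get_mscale(scaling_factor, mscale)
        / get_mscale(scaling_factor, mscale_all_dim)
        * attn_factor
    )
    print(effective)  # 1.0 whenever mscale == mscale_all_dim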
@@ -365,6 +373,8 @@ def __init__(
         attn_factor,
         beta_fast,
         beta_slow,
+        mscale: float,
+        mscale_all_dim: float,
     ):
         inv_freq = _create_inv_freq(dim, base, device)
         super().__init__(inv_freq, scaling_factor)
@@ -376,7 +386,9 @@ def __init__(
         self.beta_fast = beta_fast
         self.beta_slow = beta_slow
         self.mscale = float(
-            get_mscale(self.scaling_factor) * self.attn_factor
+            get_mscale(self.scaling_factor, mscale)
+            / get_mscale(self.scaling_factor, mscale_all_dim)
+            * self.attn_factor
         )  # Get n-d magnitude scaling corrected for interpolation

     def _update_cos_sin_cache(self, dtype, device, seqlen):
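
Written out, the corrected magnitude scaling now set in __init__ is, for scaling_factor s > 1:

    mscale_eff = attn_factor * (0.1 * mscale * ln(s) + 1) / (0.1 * mscale_all_dim * ln(s) + 1)

With the defaults read above (mscale = 1.0, mscale_all_dim = 0.0) the denominator is 1, so this reduces to the previous attn_factor * (0.1 * ln(s) + 1).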
@@ -409,9 +421,9 @@ def _update_cos_sin_cache(self, dtype, device, seqlen):
                 )

                 self.inv_freq = inv_freq
-                self.mscale = float(
-                    get_mscale(self.scaling_factor) * self.attn_factor
-                )  # Get n-d magnitude scaling corrected for interpolation
+                # self.mscale = float(
+                #     get_mscale(self.scaling_factor) * self.attn_factor
+                # )  # Get n-d magnitude scaling corrected for interpolation

             self._seq_len_cached = seqlen
             t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
