Add Rotary Positional Embeddings (RoPE) - part 2 of parallel attention blocks #450
base: main
Conversation
Codecov Report
Additional details and impacted files:
@@ Coverage Diff @@
## main #450 +/- ##
==========================================
+ Coverage 69.11% 69.24% +0.13%
==========================================
Files 170 170
Lines 11524 11580 +56
==========================================
+ Hits 7965 8019 +54
- Misses 3559 3561 +2
☔ View full report in Codecov by Sentry.
Looks good! Just a few minor things, mainly around testing and comments
@@ -112,3 +115,38 @@ def test_forward(self, data, emb):
        actual = emb(data)
        expected = torch.Size([3, 5])
        assert_expected(actual.shape, expected)


    def test_rotary_embeddings_math():
Can we put these unit tests into a class? (Similar to the other tests in this file)
Yes, will do.
        return cur_freqs.view(*shape, 2)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, start_pos: Union[int, float]
Do you think it makes sense to have start_pos default to 0? (My assumption is that this would at least be the starting point for most users)
        Maximum expected sequence length for the model; if exceeded, the cached freqs will be recomputed
    ratio: int
        The ratio for the geometric progression used to compute the rotation angles
    """
It'd be nice for the docstring to spell out the exact details of these embeddings, e.g. at least the [[cos, -sin], [sin, cos]] rotation matrix, and maybe even a small example (like the simple 2D one you wrote for the unit test).
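For context, the rotation the reviewer is referring to could be illustrated with a sketch like the following (the helper name `rotate_2d` is hypothetical, not part of the PR): each consecutive pair of embedding dimensions is rotated by an angle via the standard 2D rotation matrix.

```python
import math
import torch

def rotate_2d(x: torch.Tensor, theta: float) -> torch.Tensor:
    """Apply the 2D rotation [[cos, -sin], [sin, cos]] to a pair of features.

    This is the building block of RoPE: each (even, odd) feature pair is
    rotated by a position-dependent angle theta.
    """
    cos, sin = math.cos(theta), math.sin(theta)
    rot = torch.tensor([[cos, -sin], [sin, cos]])
    return x @ rot.T

# Rotating the unit vector [1, 0] by pi/2 should give [0, 1].
v = torch.tensor([1.0, 0.0])
out = rotate_2d(v, math.pi / 2)
```

A short example along these lines in the docstring would make the [[cos, -sin], [sin, cos]] structure immediately visible to readers.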
    assert_expected(qr[0, :, 1], qr2[1, :, 0])

    assert_expected(kr[0], kr2[0])
    assert_expected(kr[0, :, 1], kr2[1, :, 0])
Can we also add a test for updating the cached frequencies? (As far as I can tell this second test is not hitting that block in L262-268, lmk if I'm misunderstanding)
Yes, that's a good idea.
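A test for the cache-update path could follow a pattern like the sketch below. Note this uses a minimal stand-in class (`FreqCache`, a hypothetical name) rather than the PR's actual `RotaryEmbeddings`, purely to show the shape of such a test: request a sequence longer than the cached range and assert the cache grew.

```python
import torch

class FreqCache:
    """Minimal sketch of a lazily-grown RoPE frequency cache (not the PR's class)."""

    def __init__(self, dim: int, max_seq_len: int = 16, ratio: float = 10000.0):
        self.dim = dim
        self.ratio = ratio
        self.max_seq_len_cached = 0
        self._recompute(max_seq_len)

    def _recompute(self, seq_len: int) -> None:
        # Geometric progression of inverse frequencies, one per feature pair.
        inv_freq = 1.0 / (self.ratio ** (torch.arange(0, self.dim, 2).float() / self.dim))
        t = torch.arange(seq_len).float()
        self.freqs = torch.outer(t, inv_freq)  # (seq_len, dim // 2)
        self.max_seq_len_cached = seq_len

    def get(self, start_pos: int, seq_len: int) -> torch.Tensor:
        # Rebuild the cache when the requested range exceeds what was cached.
        if start_pos + seq_len > self.max_seq_len_cached:
            self._recompute(start_pos + seq_len)
        return self.freqs[start_pos : start_pos + seq_len]

cache = FreqCache(dim=8, max_seq_len=16)
_ = cache.get(start_pos=0, seq_len=32)  # exceeds the cache, forcing a rebuild
```

A unit test against the real module would do the same: call `forward` with a sequence past `max_seq_len` and assert the cached buffer was recomputed.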
        k_ = k.float().reshape(*k.shape[:-1], -1, 2)  # B H L D/2 2

        if isinstance(start_pos, int):
            if start_pos + seq_len > self.max_seq_len_cached:
Some comments here about when the frequencies need to be recomputed might be helpful
Sounds good. Offhand, the recompute triggers should be: a change of dtype, a change of device, and requesting a sequence length greater than max_seq_len.
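Those three triggers could be captured in a single predicate, sketched below. The function name `needs_recompute` and its parameters are hypothetical, just illustrating the conditions listed above.

```python
import torch

def needs_recompute(cached: torch.Tensor, x: torch.Tensor,
                    start_pos: int, seq_len: int,
                    max_seq_len_cached: int) -> bool:
    """Return True if the cached frequencies must be rebuilt for this input."""
    return (
        start_pos + seq_len > max_seq_len_cached  # sequence exceeds the cache
        or cached.dtype != x.dtype                # input arrived in a new dtype
        or cached.device != x.device              # input moved to a new device
    )

cached = torch.zeros(16, 4)                       # float32 cache on CPU
x = torch.zeros(2, 8, 4, dtype=torch.float16)     # dtype mismatch triggers a rebuild
flag = needs_recompute(cached, x, start_pos=0, seq_len=8, max_seq_len_cached=16)
```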
        )
        self.compute_freqs_cis(max_position_embeddings)

    def compute_freqs_cis(
Random q: what does cis mean here?
It's short for the rotation transform: technically we're computing e^(i*alpha) = cos(alpha) + i*sin(alpha), which shortens to cos + i*sin = "cis".
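The identity is easy to check numerically; here is a quick illustration of cis(alpha) = e^(i*alpha) using only the standard library:

```python
import cmath
import math

# "cis(alpha)" is shorthand for cos(alpha) + i*sin(alpha), which by
# Euler's formula equals e^(i*alpha).
alpha = math.pi / 3
via_exp = cmath.exp(1j * alpha)
via_cis = complex(math.cos(alpha), math.sin(alpha))
```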
We should probably add that to the docstring, actually; otherwise it's too cryptic.
@rohan-varma has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
High-level comment: let's maybe create a modules/layers/embeddings folder in the future, as we might have multiple embedding layers.
Summary:
Adds Rotary Positional Embeddings (RoPE)
Test plan:
Two unit tests: one for the math, one for padding.