[WarpSpec] improve allocation for smem #7

Merged
merged 4 commits into from
Dec 13, 2024

Contributor

manman-ren commented Dec 9, 2024

Summary: Attempt to teach the Allocation analysis to be aware of warpspec regions. Add a list of regions to each buffer, and teach the interference graph to be aware of regions. Currently this lets convert_layout within one consumer overlap.

Test Plan: Run JFA bwd

Summary: Attempt to teach the Allocation analysis to be aware of warpspec regions. Add a list of regions to each buffer, and teach the interference graph to be aware of regions.
Currently this lets convert_layout within one consumer overlap and, in the non-persistent case, lets convert_layout share memory with the private global buffer.
For the persistent case, we need to make sure the producer doesn't reload the private global buffer for the outer loop (i.e. the persistent loop) before convert_layout happens in the consumer.

Test Plan: Run JFA bwd

facebook-github-bot added the CLA Signed label Dec 9, 2024
return true;
}
}
return false;
Contributor

Can one buffer have a region id while the other doesn't, and should they then be treated as being in different regions?

Contributor Author

Yeah, we can be conservative. I am currently trying to handle the private buffer associated with channels; the check for "!= 0", i.e. ignoring the producer warp group, is kind of hacky.

: maxId;
}
if (operationId[liveOp] < minId) {
minId = operationId[liveOp];
Contributor

I guess isLocalForWS falls through here. How does it make sure the scratch buffer (convert layout) of the second consumer overlaps with that of the first consumer? Is it handled in buildInterferenceGraph?

Contributor Author

Yeah, in buildInterferenceGraph, ops with different taskIds will interfere with each other.

}
if (isPrivateGlobalForWS) {
minId = 0;
maxId = operationId[liveOp] + 1 > maxId ? operationId[liveOp] + 1
Contributor

For buffers used inside a loop, should we use the operationId of the loop here, which should cover the whole body of the loop? Perhaps the outermost loop should be used, which should handle the persistent loop case.
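
The live-range logic in the hunk above can be sketched roughly as follows. This is a minimal, self-contained illustration, not the actual AllocationAnalysis code: `liveRangeFor` and its parameters are hypothetical names, and only `minId`, `maxId`, and `isPrivateGlobalForWS` come from the diff. It assumes operation ids are plain indices and that the resulting range is half-open, [minId, maxId).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Hypothetical helper: returns the half-open live range [minId, maxId)
// spanned by the ids of the ops that keep a buffer alive.
std::pair<std::size_t, std::size_t>
liveRangeFor(const std::vector<std::size_t> &liveOpIds,
             bool isPrivateGlobalForWS) {
  std::size_t minId = std::numeric_limits<std::size_t>::max();
  std::size_t maxId = 0;
  for (std::size_t id : liveOpIds) {
    minId = std::min(minId, id);
    maxId = std::max(maxId, id + 1);
  }
  // Mirrors the diff: a private global buffer is treated as live from
  // the very first operation onward.
  if (isPrivateGlobalForWS && !liveOpIds.empty())
    minId = 0;
  return {minId, maxId};
}
```

With op ids {3, 5, 4}, an ordinary buffer gets [3, 6) while a private global buffer gets [0, 6). Covering a whole loop, as suggested above, would presumably mean using the loop op's own id range instead of the ids of ops inside the body; that part is not sketched here.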

@@ -548,6 +630,9 @@ class AllocationAnalysis {
xSizeRange.intersects(ySizeRange)) {
interference[x].insert(y);
}
// if x and y belong to different regions (ignore producer region).
if (inDifferentRegion(x, y) && xSizeRange.intersects(yOpRange))
Contributor Author

Oh, there is a typo here: xSizeRange.intersects(yOpRange) should be xSizeRange.intersects(ySizeRange).

Contributor

While I was looking at that code, I wonder if we should always make x and y intersect, if they are from different regions.
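
A minimal sketch of the interference rule with the typo corrected, assuming half-open intervals and a single region id per buffer (the patch itself keeps a list of region ids). `Interval`, `BufferInfo`, and `interferes` are hypothetical names used only for illustration, and treating region id 0 as the producer is an assumption based on the "!= 0" check discussed earlier in this review.

```cpp
#include <cassert>
#include <cstddef>

// Half-open interval [start, end).
struct Interval {
  std::size_t start;
  std::size_t end;
  bool intersects(const Interval &o) const {
    return start < o.end && o.start < end;
  }
};

// Hypothetical per-buffer record for this sketch.
struct BufferInfo {
  Interval opRange;   // live range in operation ids
  Interval sizeRange; // [offset, offset + size) in shared memory
  int region;         // assumption: 0 means the producer warp group
};

bool interferes(const BufferInfo &x, const BufferInfo &y) {
  // Base rule: live at the same time and occupying overlapping bytes.
  if (x.opRange.intersects(y.opRange) &&
      x.sizeRange.intersects(y.sizeRange))
    return true;
  // Region rule: buffers from different non-producer regions may run
  // concurrently, so only their size ranges are compared.
  bool differentRegion =
      x.region != 0 && y.region != 0 && x.region != y.region;
  return differentRegion && x.sizeRange.intersects(y.sizeRange);
}
```

Comparing a size range against an op range, as in the original line, mixes byte offsets with operation ids, which is why the corrected condition compares size range with size range.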

auto tA = A->regionIds;
auto tB = B->regionIds;
for (auto t1 : tA) {
for (auto t2 : tA) {
Contributor Author

Another typo here: the inner loop should iterate over tB. Will try to clean this up.
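
For reference, a hedged sketch of what the region check might look like with the inner loop iterating over tB. The `Buffer` struct here is a hypothetical stand-in (only the `regionIds` field name comes from the snippet), and skipping region id 0 as the producer warp group is an assumption drawn from the "!= 0" check discussed in this review; the real function body is not shown in the diff.

```cpp
#include <cassert>
#include <set>

// Hypothetical stand-in for the allocator's buffer record.
struct Buffer {
  std::set<int> regionIds;
};

// Assumption: region id 0 denotes the producer warp group and is ignored.
constexpr int kProducerRegion = 0;

// Returns true if some non-producer region of A differs from some
// non-producer region of B.
bool inDifferentRegion(const Buffer &A, const Buffer &B) {
  const auto &tA = A.regionIds;
  const auto &tB = B.regionIds;
  for (int t1 : tA) {
    if (t1 == kProducerRegion)
      continue;
    for (int t2 : tB) { // iterate over B's ids, not A's
      if (t2 == kProducerRegion)
        continue;
      if (t1 != t2)
        return true;
    }
  }
  return false;
}
```

In this sketch a buffer with no region ids never counts as being in a different region; the more conservative option raised earlier in the review would return true for that case instead.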

Contributor

htyu commented Dec 13, 2024

It seems that the patch increases SMEM usage for one of the GEMM kernels. I'll take a deeper look.

Contributor Author

It seems that the patch increases SMEM usage for one of the GEMM kernels. I'll take a deeper look.

Oh that is weird.

@manman-ren manman-ren merged commit 05ee274 into ws Dec 13, 2024
1 check passed
htyu pushed a commit that referenced this pull request Dec 13, 2024
Contributor

htyu commented Dec 13, 2024

It seems that the patch increases SMEM usage for one of the GEMM kernels. I'll take a deeper look.

Oh that is weird.

Fixed by #8

3 participants