Feature/transformer sequence sharding #67
base: develop
Changes from all commits: e3c5283, e727e3c, a847f1a, ec0ac2d, 0aaf0f9, 943a0d4, 063844b, d1483b0
The diff touches two files: the multi-head self-attention layer and the transformer processor that instantiates it.

Attention layer:

```diff
@@ -25,6 +25,8 @@
 else:
     _FLASH_ATTENTION_AVAILABLE = True

+from anemoi.models.distributed.transformer import halo_exchange
 from anemoi.models.distributed.transformer import shard_heads
 from anemoi.models.distributed.transformer import shard_sequence

@@ -42,6 +44,7 @@ def __init__(
         is_causal: bool = False,
         window_size: Optional[int] = None,
         dropout_p: float = 0.0,
+        shard_strategy: str = "shard_heads",
     ):
         super().__init__()

@@ -55,38 +58,75 @@ def __init__(
         self.window_size = (window_size, window_size)  # flash attention
         self.dropout_p = dropout_p
         self.is_causal = is_causal
+        self.shard_strategy = shard_strategy

         self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
         self.attention = attn_func

         if not _FLASH_ATTENTION_AVAILABLE:
             LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention")

+        if shard_strategy not in ["shard_heads", "shard_sequence"]:
+            raise ValueError(f"Invalid shard_strategy: {shard_strategy}")

+        if shard_strategy == "shard_sequence":  # remove this after PR #47 is merged (sliding window support)
+            assert _FLASH_ATTENTION_AVAILABLE, "Flash attention is required for shard_sequence strategy"

         self.projection = nn.Linear(embed_dim, embed_dim, bias=True)

     def forward(
         self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
     ) -> Tensor:
-        query, key, value = self.lin_qkv(x).chunk(3, -1)

         if model_comm_group:
             assert (
                 model_comm_group.size() == 1 or batch_size == 1
             ), "Only batch size of 1 is supported when model is sharded accross GPUs"

-        query, key, value = (
-            einops.rearrange(
-                t,
-                "(batch grid) (heads vars) -> batch heads grid vars",
-                batch=batch_size,
-                heads=self.num_heads,
-            )
-            for t in (query, key, value)
-        )
-
-        query = shard_heads(query, shapes=shapes, mgroup=model_comm_group)
-        key = shard_heads(key, shapes=shapes, mgroup=model_comm_group)
-        value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
+        if self.shard_strategy == "shard_sequence":
```
Review comment: this is now very long. can we introduce e.g. something like separate `shard_sequence` and `shard_heads` code paths for the qkv computation?

Review comment: agreed, the if & else blocks should be refactored as separate (member) functions
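A rough illustration of that suggestion follows; the method names and signatures are hypothetical and not part of the PR, the point is only that `forward()` would shrink to a dispatch:

```python
# Sketch of the suggested refactor: each sharding strategy gets its own member
# function. Names/signatures are illustrative, not taken from the PR.


class MultiHeadSelfAttentionSketch:
    def __init__(self, shard_strategy: str = "shard_heads") -> None:
        if shard_strategy not in ("shard_heads", "shard_sequence"):
            raise ValueError(f"Invalid shard_strategy: {shard_strategy}")
        self.shard_strategy = shard_strategy

    def _qkv_shard_sequence(self, x, shapes, batch_size, model_comm_group):
        # halo exchange + q/k/v projection on the local sequence shard would live here;
        # it would also need to return the halo sizes for later removal
        raise NotImplementedError

    def _qkv_shard_heads(self, x, shapes, batch_size, model_comm_group):
        # q/k/v projection + the shard_heads calls would live here
        raise NotImplementedError

    def forward(self, x, shapes, batch_size, model_comm_group=None):
        # the long if/else from the diff collapses to a dispatch
        if self.shard_strategy == "shard_sequence":
            query, key, value = self._qkv_shard_sequence(x, shapes, batch_size, model_comm_group)
        else:  # shard_heads
            query, key, value = self._qkv_shard_heads(x, shapes, batch_size, model_comm_group)
        # ... attention, halo removal / sequence re-sharding, output projection ...
        return query, key, value
```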
```diff
+            assert (
+                shapes[-1][0] // 2 >= self.window_size[0]
+            ), "Sharded sequence length must be at least twice the window size"
```
Review comment: we could have the assert print the sharded sequence length and window size so the user sees the values that raised the error?
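A minimal sketch of that idea, assuming the same `shapes` layout as in the diff above (the helper name is hypothetical):

```python
def check_sharded_seq_len(shard_seq_len: int, window_size: int) -> None:
    """Raise with the offending values so the user sees what triggered the error."""
    assert shard_seq_len // 2 >= window_size, (
        f"Sharded sequence length ({shard_seq_len}) must be at least twice "
        f"the window size ({window_size})"
    )


# e.g. inside forward(): check_sharded_seq_len(shapes[-1][0], self.window_size[0])
```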
```diff
+            # unpack grid dimension first to allow for halo exchange
+            x_bgc = einops.rearrange(
+                x,
+                "(batch grid) channels -> batch grid channels",
+                batch=batch_size,
+            )
+
+            # communicate halos (adds halos to x)
+            x_plus_halos, halo_size_left, halo_size_right = halo_exchange(
+                x_bgc, halo_size=self.window_size[0], mgroup=model_comm_group
+            )
+
+            # compute q, k, v (on local sequence shards with halos)
+            query, key, value = self.lin_qkv(x_plus_halos).chunk(3, -1)
+
+            query, key, value = (
+                einops.rearrange(
+                    t,
+                    "batch grid (heads vars) -> batch heads grid vars",
+                    heads=self.num_heads,
+                )
+                for t in (query, key, value)
+            )
+        else:  # shard_heads
+            query, key, value = self.lin_qkv(x).chunk(3, -1)
+
+            query, key, value = (
+                einops.rearrange(
+                    t,
+                    "(batch grid) (heads vars) -> batch heads grid vars",
+                    batch=batch_size,
+                    heads=self.num_heads,
+                )
+                for t in (query, key, value)
+            )
+
+            query = shard_heads(query, shapes=shapes, mgroup=model_comm_group)
+            key = shard_heads(key, shapes=shapes, mgroup=model_comm_group)
+            value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
 
         dropout_p = self.dropout_p if self.training else 0.0
 
         if _FLASH_ATTENTION_AVAILABLE:
```
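For intuition, here is a single-process toy of the index bookkeeping the sequence-sharding path relies on: each rank keeps its own contiguous slice plus `window_size` halo points from each neighbour, attends over the padded slice, and drops the halos afterwards. This only mimics the shapes; the real `halo_exchange` performs the communication between ranks.

```python
# Toy illustration of halo bookkeeping (no torch.distributed involved):
# a global sequence split across two "ranks", each padded with window_size halo points.

window_size = 2
seq = list(range(16))          # pretend global sequence of 16 grid points
shards = [seq[:8], seq[8:]]    # each rank owns a contiguous shard

padded = []
for rank, shard in enumerate(shards):
    left = shards[rank - 1][-window_size:] if rank > 0 else []
    right = shards[rank + 1][:window_size] if rank < len(shards) - 1 else []
    padded.append((left + shard + right, len(left), len(right)))

# after attention, every rank slices its halos away again, analogous to
# out[:, :, halo_size_left : out.shape[-2] - halo_size_right, :] in the diff
recovered = [p[len_l : len(p) - len_r] for p, len_l, len_r in padded]
assert [x for shard in recovered for x in shard] == seq
```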
```diff
@@ -104,7 +144,11 @@ def forward(
                 dropout_p=dropout_p,
             )  # expects (batch heads grid variable) format
 
-        out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
+        if self.shard_strategy == "shard_sequence":
+            out = out[:, :, halo_size_left : out.shape[-2] - halo_size_right, :]  # remove halos
```
Review comment: I would prefer if this would happen in a function that lives at the same place as `halo_exchange`, e.g. call `halo_expand` first and then `halo_contract` (not best names).

Review comment: maybe just `remove_halos`
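A sketch of what such a companion function could look like; the name follows the review suggestion and the tensor layout is assumed from the diff above, so this is illustrative rather than the PR's implementation:

```python
import torch


def remove_halos(x: torch.Tensor, halo_size_left: int, halo_size_right: int) -> torch.Tensor:
    """Drop the halo points again after attention.

    Sketch only: assumes x has shape (batch, heads, grid, vars) and that the halos
    were added along the grid dimension by halo_exchange.
    """
    return x[:, :, halo_size_left : x.shape[-2] - halo_size_right, :]
```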
```diff
+        else:  # shard_heads
+            out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
 
         out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")
 
         out = self.projection(out)
```
Transformer processor:

```diff
@@ -97,6 +97,7 @@ def __init__(
         num_heads: int = 16,
         mlp_hidden_ratio: int = 4,
         dropout_p: float = 0.1,
+        shard_strategy: str = "shard_heads",
```
Review comment: Add to doc string below?

Review comment: is this value configurable? (how can one override the default?)
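On the first point, a possible numpydoc entry for the new argument is sketched below; the wording is only a suggestion and not part of the PR. On the second, the default can at least be overridden wherever `TransformerProcessor` is instantiated, since the keyword is threaded through to the attention layer in this PR; whether it is also exposed in the training configuration is exactly what the reviewer is asking.

```python
# Sketch of a docstring entry for the new constructor argument (numpydoc style,
# wording is a suggestion, not part of the PR):
#
#     shard_strategy : str, optional
#         How attention is sharded across model-parallel ranks, either
#         "shard_heads" (default, shards over attention heads) or
#         "shard_sequence" (shards over the grid/sequence dimension using a
#         halo exchange), by default "shard_heads".
```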
```diff
         **kwargs,
     ) -> None:
         """Initialize TransformerProcessor.
 
@@ -138,6 +139,7 @@ def __init__(
             window_size=window_size,
             activation=activation,
             dropout_p=dropout_p,
+            shard_strategy=shard_strategy,
         )
 
         self.offload_layers(cpu_offload)
```
Review comment: I was wondering: we now have `halo_exchange`, `_halo_exchange`, and `_HaloExchange`. Would it make sense to come up with more unique / more descriptive names for these? I think this might be a bit confusing. I admit that the names for the other routines (`shard_heads` etc.) are not great either.
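One conventional way to keep such layers apart is to let each name encode its role: a verb-phrase public wrapper and a clearly marked `torch.autograd.Function`. The names below are only an illustration of that pattern, not proposals adopted by the PR, and the bodies are placeholders:

```python
import torch


class _HaloExchangeFunction(torch.autograd.Function):
    """Autograd layer; the class name marks it as the autograd.Function."""

    @staticmethod
    def forward(ctx, x, halo_size_left, halo_size_right):
        ctx.halo_sizes = (halo_size_left, halo_size_right)
        return x  # placeholder: a real implementation would gather halos from neighbouring ranks

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None, None  # placeholder: a real implementation would reduce halo gradients


def exchange_sequence_halos(x, halo_size_left, halo_size_right):
    """Public entry point; the verb phrase distinguishes it from the Function class."""
    return _HaloExchangeFunction.apply(x, halo_size_left, halo_size_right)
```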