Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/transformer sequence sharding #67

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
1 change: 1 addition & 0 deletions models/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Keep it human-readable, your future self will thank you!
- GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46)
- New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64)
- Contributors file [#69](https://github.com/ecmwf/anemoi-models/pull/69)
- Add sequence sharding strategy for TransformerProcessor [#90](https://github.com/ecmwf/anemoi-models/pull/90)

### Changed
- Bugfixes for CI
Expand Down
123 changes: 123 additions & 0 deletions models/src/anemoi/models/distributed/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# nor does it submit to any jurisdiction.


import logging
from typing import Optional

import torch
Expand All @@ -17,6 +18,8 @@

from anemoi.models.distributed.utils import get_memory_format

LOGGER = logging.getLogger(__name__)


def _headsalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = None) -> Tensor:
"""Apply all_to_all along the head dimension.
Expand Down Expand Up @@ -82,6 +85,72 @@ def _seqalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = N
return torch.cat(output_list, dim=-3).contiguous(memory_format=input_format)


def _halo_exchange(input_: Tensor, halo_size: int, mgroup: ProcessGroup, bwd: bool = False) -> Tensor:
"""Exchange halo regions between neighboring ranks.

Expected format is (batch_size, halo_size + sequence_length + halo_size, channels).

Parameters
----------
input_ : Tensor
Input tensor
halo_size : int
Halo size (left, right)
mgroup : ProcessGroup
Model communication group
bwd : bool
Flag to indicate if backward pass

Returns
-------
Tensor
Tensor with halo regions from neighboring ranks
"""
end = input_.shape[-2]

left_halo_slice = slice(0, halo_size)
right_halo_slice = slice(end - halo_size, end)
left_send_slice = slice(halo_size, 2 * halo_size)
right_send_slice = slice(end - 2 * halo_size, end - halo_size)

if bwd: # reverse halo exchange direction for gradient accumulation
left_halo_slice, left_send_slice = left_send_slice, left_halo_slice
right_halo_slice, right_send_slice = right_send_slice, right_halo_slice

left_send = input_[:, left_send_slice, :]
right_send = input_[:, right_send_slice, :]

# setup neighbor ranks and tensor lists for all_to_all communication
group_rank = dist.get_rank(mgroup)
group_size = dist.get_world_size(mgroup)
left_rank = group_rank - 1 if group_rank > 0 else None
right_rank = group_rank + 1 if group_rank < group_size - 1 else None

input_list = [torch.empty(0, device=input_.device) for _ in range(group_size)]
if left_rank is not None:
input_list[left_rank] = left_send
if right_rank is not None:
input_list[right_rank] = right_send
output_list = [torch.empty_like(input_i, device=input_.device) for input_i in input_list]

dist.all_to_all(output_list, input_list, group=mgroup)

if bwd: # add gradient contributions to halo regions and zero out send regions
if left_rank is not None:
input_[:, left_send_slice, :] = 0
input_[:, left_halo_slice, :] += output_list[left_rank]
if right_rank is not None:
input_[:, right_send_slice, :] = 0
input_[:, right_halo_slice, :] += output_list[right_rank]
else: # add halo regions to input tensor
if left_rank is not None:
input_[:, left_halo_slice, :] = output_list[left_rank]
if right_rank is not None:
input_[:, right_halo_slice, :] = output_list[right_rank]

return input_


def shard_heads(input_: Tensor, shapes: list, mgroup: ProcessGroup) -> Tensor:
"""Sync tensor.

Expand Down Expand Up @@ -130,6 +199,36 @@ def shard_sequence(input_: Tensor, shapes: list, mgroup: ProcessGroup) -> Tensor
return _SplitSequenceParallelSection.apply(input_, shapes, mgroup)


def halo_exchange(x: Tensor, halo_size: int, mgroup: ProcessGroup) -> Tensor:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering: we now have
halo_exchange
_halo_exchange
_HaloExchange

would it make sense to come up with more unique / more descriptive names for these? I think this might be a bit confusing. I admit that the names for the other routines (shard_heads etc.) are not great either.

"""Exchange halo regions between ranks,

Parameters
----------
x : Tensor
Input tensor
halo_size : int
Halo size (left, right)
mgroup : ProcessGroup
Model communication group

Returns
-------
Tensor, int, int
Tensor appended with halo regions from neighboring ranks, left halo size, right halo size
"""
if mgroup is None or dist.get_world_size(mgroup) == 1:
return x, 0, 0

# pad tensor with halo regions
halo_size_left = halo_size if dist.get_rank(mgroup) != 0 else 0
halo_size_right = halo_size if dist.get_rank(mgroup) != dist.get_world_size(mgroup) - 1 else 0
x_pad = torch.nn.functional.pad(x, pad=(0, 0, halo_size_left, halo_size_right), mode="constant", value=0)

out = _HaloExchange.apply(x_pad, halo_size, mgroup)

return out, halo_size_left, halo_size_right


class _SplitHeadsParallelSection(torch.autograd.Function):
"""Sync the input from parallel section."""

Expand Down Expand Up @@ -172,3 +271,27 @@ def backward(ctx, grad_output):
None,
)
return grad_output, None, None


class _HaloExchange(torch.autograd.Function):
"""Exchange halo regions between ranks."""

@staticmethod
def forward(ctx, input_, halo_size_, mgroup_):
ctx.halo_size = halo_size_
ctx.mgroup = mgroup_

if mgroup_:
return _halo_exchange(input_, halo_size_, mgroup_)
return input_

@staticmethod
def backward(ctx, grad_output):
if ctx.mgroup:
return (
_halo_exchange(grad_output, ctx.halo_size, ctx.mgroup, bwd=True),
None,
None,
)

return grad_output, None, None
70 changes: 57 additions & 13 deletions models/src/anemoi/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
else:
_FLASH_ATTENTION_AVAILABLE = True


from anemoi.models.distributed.transformer import halo_exchange
from anemoi.models.distributed.transformer import shard_heads
from anemoi.models.distributed.transformer import shard_sequence

Expand All @@ -42,6 +44,7 @@ def __init__(
is_causal: bool = False,
window_size: Optional[int] = None,
dropout_p: float = 0.0,
shard_strategy: str = "shard_heads",
):
super().__init__()

Expand All @@ -55,38 +58,75 @@ def __init__(
self.window_size = (window_size, window_size) # flash attention
self.dropout_p = dropout_p
self.is_causal = is_causal
self.shard_strategy = shard_strategy

self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
self.attention = attn_func

if not _FLASH_ATTENTION_AVAILABLE:
LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention")

if shard_strategy not in ["shard_heads", "shard_sequence"]:
raise ValueError(f"Invalid shard_strategy: {shard_strategy}")

if shard_strategy == "shard_sequence": # remove this after PR #47 is merged (sliding window support)
assert _FLASH_ATTENTION_AVAILABLE, "Flash attention is required for shard_sequence strategy"

self.projection = nn.Linear(embed_dim, embed_dim, bias=True)

def forward(
self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
) -> Tensor:
query, key, value = self.lin_qkv(x).chunk(3, -1)

if model_comm_group:
assert (
model_comm_group.size() == 1 or batch_size == 1
), "Only batch size of 1 is supported when model is sharded accross GPUs"

query, key, value = (
einops.rearrange(
t,
"(batch grid) (heads vars) -> batch heads grid vars",
if self.shard_strategy == "shard_sequence":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is now very long. can we introduce e.g. something like

`if if self.shard_strategy == "shard_sequence":
x = self.shard_sequence(x)

query, key, value = self.lin_qkv(x).chunk(3, -1)

query, key, value = (
einops.rearrange(
t,
"(batch grid) (heads vars) -> batch heads grid vars",
batch=batch_size,
heads=self.num_heads,
)
for t in (query, key, value)
)

if if self.shard_strategy == "shard_heads"
query = shard_heads(query, shapes=shapes, mgroup=model_comm_group)
key = shard_heads(key, shapes=shapes, mgroup=model_comm_group)
value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
.
.
.
.
`

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed, the if & else blocks should be refactored as separate (member) functions

assert (
shapes[-1][0] // 2 >= self.window_size[0]
), "Sharded sequence length must be at least twice the window size"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could have the assert print the sharded sequence length and window size so the user sees the values that raised the error?


# unpack grid dimension first to allow for halo exchange
x_bgc = einops.rearrange(
x,
"(batch grid) channels -> batch grid channels",
batch=batch_size,
heads=self.num_heads,
)
for t in (query, key, value)
)

query = shard_heads(query, shapes=shapes, mgroup=model_comm_group)
key = shard_heads(key, shapes=shapes, mgroup=model_comm_group)
value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
# communicate halos (adds halos to x)
x_plus_halos, halo_size_left, halo_size_right = halo_exchange(
x_bgc, halo_size=self.window_size[0], mgroup=model_comm_group
)

# compute q, k, v (on local sequence shards with halos)
query, key, value = self.lin_qkv(x_plus_halos).chunk(3, -1)

query, key, value = (
einops.rearrange(
t,
"batch grid (heads vars) -> batch heads grid vars",
heads=self.num_heads,
)
for t in (query, key, value)
)
else: # shard_heads
query, key, value = self.lin_qkv(x).chunk(3, -1)

query, key, value = (
einops.rearrange(
t,
"(batch grid) (heads vars) -> batch heads grid vars",
batch=batch_size,
heads=self.num_heads,
)
for t in (query, key, value)
)

query = shard_heads(query, shapes=shapes, mgroup=model_comm_group)
key = shard_heads(key, shapes=shapes, mgroup=model_comm_group)
value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)

dropout_p = self.dropout_p if self.training else 0.0

if _FLASH_ATTENTION_AVAILABLE:
Expand All @@ -104,7 +144,11 @@ def forward(
dropout_p=dropout_p,
) # expects (batch heads grid variable) format

out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
if self.shard_strategy == "shard_sequence":
out = out[:, :, halo_size_left : out.shape[-2] - halo_size_right, :] # remove halos
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer if this would happen in a function that lives at the same place as halo_exchange, e.g. call halo_expand first and then halo_contract (not best names).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe just remove_halos

else: # shard_heads
out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)

out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")

out = self.projection(out)
Expand Down
2 changes: 2 additions & 0 deletions models/src/anemoi/models/layers/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
activation: str,
window_size: int,
dropout_p: float = 0.0,
shard_strategy: str = "shard_heads",
):
super().__init__()

Expand All @@ -87,6 +88,7 @@ def __init__(
bias=False,
is_causal=False,
dropout_p=dropout_p,
shard_strategy=shard_strategy,
)

self.mlp = nn.Sequential(
Expand Down
2 changes: 2 additions & 0 deletions models/src/anemoi/models/layers/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
mlp_hidden_ratio: int = 4,
activation: str = "GELU",
dropout_p: float = 0.0,
shard_strategy: str = "shard_heads",
) -> None:
"""Initialize TransformerProcessor.

Expand Down Expand Up @@ -103,6 +104,7 @@ def __init__(
activation=activation,
window_size=window_size,
dropout_p=dropout_p,
shard_strategy=shard_strategy,
)

def forward(
Expand Down
2 changes: 2 additions & 0 deletions models/src/anemoi/models/layers/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
dropout_p: float = 0.1,
shard_strategy: str = "shard_heads",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add to doc string below?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this value configurable? (how can one override the default?)

**kwargs,
) -> None:
"""Initialize TransformerProcessor.
Expand Down Expand Up @@ -138,6 +139,7 @@ def __init__(
window_size=window_size,
activation=activation,
dropout_p=dropout_p,
shard_strategy=shard_strategy,
)

self.offload_layers(cpu_offload)
Expand Down
Loading