From e3c52838de926cdaaa9c85e0d5ec173023eb24da Mon Sep 17 00:00:00 2001 From: japols Date: Tue, 29 Oct 2024 10:49:53 +0000 Subject: [PATCH 1/5] feat: Initial transformer sequence sharding version --- src/anemoi/models/distributed/transformer.py | 115 +++++++++++++++++++ src/anemoi/models/layers/attention.py | 70 ++++++++--- 2 files changed, 172 insertions(+), 13 deletions(-) diff --git a/src/anemoi/models/distributed/transformer.py b/src/anemoi/models/distributed/transformer.py index 78691bba..aadd7e3f 100644 --- a/src/anemoi/models/distributed/transformer.py +++ b/src/anemoi/models/distributed/transformer.py @@ -8,6 +8,7 @@ # nor does it submit to any jurisdiction. +import logging from typing import Optional import torch @@ -17,6 +18,8 @@ from anemoi.models.distributed.utils import get_memory_format +LOGGER = logging.getLogger(__name__) + def _headsalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = None) -> Tensor: """Apply all_to_all along the head dimension. @@ -82,6 +85,71 @@ def _seqalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = N return torch.cat(output_list, dim=-3).contiguous(memory_format=input_format) +def _halo_exchange(input_: Tensor, halo_size: int, mgroup: ProcessGroup, bwd: bool = False) -> Tensor: + """Exchange halo regions between ranks. + + Send halo regions to the left and right and receive from the right and left and extend the input tensor. + Expected format is (batch_size, sequence_length, channels). + + Parameters + ---------- + input_ : Tensor + Input tensor + halo_size : int + Halo size (left, right) + mgroup : ProcessGroup + Model communication group + bwd : bool + Flag to indicate if backward pass + + Returns + ------- + Tensor + Extended input tensor + """ + end = input_.shape[-2] + + left_halo = input_[:, :halo_size, :] + right_halo = input_[:, end - halo_size :, :] + + left_send = input_[:, halo_size : 2 * halo_size, :] + right_send = input_[:, end - 2 * halo_size : end - halo_size, :] + + if bwd: # reverse halo exchange + left_halo, left_send = left_send, left_halo + right_halo, right_send = right_send, right_halo + + my_rank = dist.get_rank(mgroup) + group_size = dist.get_world_size(mgroup) + left_rank = dist.get_rank(mgroup) - 1 + right_rank = dist.get_rank(mgroup) + 1 + + if my_rank % 2 != 0: + # send left (can't be rank 0) + dist.send(left_send, left_rank, group=mgroup) + # receive left + dist.recv(left_halo, left_rank, group=mgroup) + + if my_rank != group_size - 1: + # send right + dist.send(right_send, right_rank, group=mgroup) + # receive right + dist.recv(right_halo, right_rank, group=mgroup) + else: + if my_rank != group_size - 1: + # receive right + dist.recv(right_halo, right_rank, group=mgroup) + # send right + dist.send(right_send, right_rank, group=mgroup) + if my_rank != 0: + # receive left + dist.recv(left_halo, left_rank, group=mgroup) + # send left + dist.send(left_send, left_rank, group=mgroup) + + return input_ + + def shard_heads(input_: Tensor, shapes: list, mgroup: ProcessGroup) -> Tensor: """Sync tensor. @@ -130,6 +198,28 @@ def shard_sequence(input_: Tensor, shapes: list, mgroup: ProcessGroup) -> Tensor return _SplitSequenceParallelSection.apply(input_, shapes, mgroup) +def halo_exchange(x: Tensor, halo_size: int, mgroup: ProcessGroup) -> None: + """Exchange halo regions between ranks, + + Parameters + ---------- + x : Tensor + Input tensor + halo_size : int + Halo size (left, right) + mgroup : ProcessGroup + Model communication group + """ + # pad tensor with halo regions + halo_size_left = halo_size if (mgroup and dist.get_rank(mgroup) != 0) else 0 + halo_size_right = halo_size if (mgroup and dist.get_rank(mgroup) != dist.get_world_size(mgroup) - 1) else 0 + x_pad = torch.nn.functional.pad(x, pad=(0, 0, halo_size_left, halo_size_right), mode="constant", value=0) + + out = _HaloExchange.apply(x_pad, halo_size, mgroup) + + return out, halo_size_left, halo_size_right + + class _SplitHeadsParallelSection(torch.autograd.Function): """Sync the input from parallel section.""" @@ -172,3 +262,28 @@ def backward(ctx, grad_output): None, ) return grad_output, None, None + + +class _HaloExchange(torch.autograd.Function): + """Exchange halo regions between ranks.""" + + @staticmethod + def forward(ctx, input_, halo_size_, mgroup_): + ctx.halo_size = halo_size_ + ctx.mgroup = mgroup_ + + if mgroup_: + return _halo_exchange(input_, halo_size_, mgroup_) + return input_ + + @staticmethod + def backward(ctx, grad_output): + if ctx.mgroup: + return ( + # not sure if this works yet, need to test + _halo_exchange(grad_output, ctx.halo_size, ctx.mgroup, bwd=True), + None, + None, + ) + + return grad_output, None, None diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py index d7f54920..921522b6 100644 --- a/src/anemoi/models/layers/attention.py +++ b/src/anemoi/models/layers/attention.py @@ -25,6 +25,8 @@ else: _FLASH_ATTENTION_AVAILABLE = True + +from anemoi.models.distributed.transformer import halo_exchange from anemoi.models.distributed.transformer import shard_heads from anemoi.models.distributed.transformer import shard_sequence @@ -42,6 +44,7 @@ def __init__( is_causal: bool = False, window_size: Optional[int] = None, dropout_p: float = 0.0, + strategy: str = "shard_heads", ): super().__init__() @@ -55,6 +58,7 @@ def __init__( self.window_size = (window_size, window_size) # flash attention self.dropout_p = dropout_p self.is_causal = is_causal + self.strategy = strategy self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias) self.attention = attn_func @@ -62,31 +66,68 @@ def __init__( if not _FLASH_ATTENTION_AVAILABLE: LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention") + if strategy not in ["shard_heads", "shard_sequence"]: + raise ValueError(f"Invalid strategy: {strategy}") + self.projection = nn.Linear(embed_dim, embed_dim, bias=True) def forward( self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None ) -> Tensor: - query, key, value = self.lin_qkv(x).chunk(3, -1) - if model_comm_group: assert ( model_comm_group.size() == 1 or batch_size == 1 ), "Only batch size of 1 is supported when model is sharded accross GPUs" - query, key, value = ( - einops.rearrange( - t, - "(batch grid) (heads vars) -> batch heads grid vars", + if self.strategy == "shard_sequence": + assert _FLASH_ATTENTION_AVAILABLE, "Flash attention is required for shard_sequence strategy" + assert ( + shapes[-1][0] // 2 >= self.window_size[0] + ), "Sharded sequence length must be at least twice the window size" + + # unpack grid dimension first to allow for halo exchange + x_bgc = einops.rearrange( + x, + "(batch grid) channels -> batch grid channels", batch=batch_size, - heads=self.num_heads, ) - for t in (query, key, value) - ) - query = shard_heads(query, shapes=shapes, mgroup=model_comm_group) - key = shard_heads(key, shapes=shapes, mgroup=model_comm_group) - value = shard_heads(value, shapes=shapes, mgroup=model_comm_group) + # communicate halos (adds halos to x) + x_plus_halos, halo_size_left, halo_size_right = halo_exchange( + x_bgc, halo_size=self.window_size[0], mgroup=model_comm_group + ) + + # compute q, k, v (on local sequence shards) + query, key, value = self.lin_qkv(x_plus_halos).chunk(3, -1) + + # further unpack feature dimension + query, key, value = ( + einops.rearrange( + t, + "batch grid (heads vars) -> batch heads grid vars", + heads=self.num_heads, + ) + for t in (query, key, value) + ) + + else: + query, key, value = self.lin_qkv(x).chunk(3, -1) + + query, key, value = ( + einops.rearrange( + t, + "(batch grid) (heads vars) -> batch heads grid vars", + batch=batch_size, + heads=self.num_heads, + ) + for t in (query, key, value) + ) + + if self.strategy == "shard_heads": + query = shard_heads(query, shapes=shapes, mgroup=model_comm_group) + key = shard_heads(key, shapes=shapes, mgroup=model_comm_group) + value = shard_heads(value, shapes=shapes, mgroup=model_comm_group) + dropout_p = self.dropout_p if self.training else 0.0 if _FLASH_ATTENTION_AVAILABLE: @@ -104,7 +145,10 @@ def forward( dropout_p=dropout_p, ) # expects (batch heads grid variable) format - out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group) + if self.strategy == "shard_sequence": + out = out[:, :, halo_size_left : out.shape[-2] - halo_size_right, :] # remove halos + if self.strategy == "shard_heads": + out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group) out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)") out = self.projection(out) From e727e3c34be59cf1f8bf8993d1df13acbc5179e0 Mon Sep 17 00:00:00 2001 From: japols Date: Tue, 29 Oct 2024 13:03:51 +0000 Subject: [PATCH 2/5] feat: shard_strategy configurable via config.model.processor --- src/anemoi/models/layers/attention.py | 16 ++++++++-------- src/anemoi/models/layers/block.py | 2 ++ src/anemoi/models/layers/chunk.py | 2 ++ src/anemoi/models/layers/processor.py | 2 ++ 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py index 921522b6..6c9f91e3 100644 --- a/src/anemoi/models/layers/attention.py +++ b/src/anemoi/models/layers/attention.py @@ -44,7 +44,7 @@ def __init__( is_causal: bool = False, window_size: Optional[int] = None, dropout_p: float = 0.0, - strategy: str = "shard_heads", + shard_strategy: str = "shard_heads", ): super().__init__() @@ -58,7 +58,7 @@ def __init__( self.window_size = (window_size, window_size) # flash attention self.dropout_p = dropout_p self.is_causal = is_causal - self.strategy = strategy + self.shard_strategy = shard_strategy self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias) self.attention = attn_func @@ -66,8 +66,8 @@ def __init__( if not _FLASH_ATTENTION_AVAILABLE: LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention") - if strategy not in ["shard_heads", "shard_sequence"]: - raise ValueError(f"Invalid strategy: {strategy}") + if shard_strategy not in ["shard_heads", "shard_sequence"]: + raise ValueError(f"Invalid shard_strategy: {shard_strategy}") self.projection = nn.Linear(embed_dim, embed_dim, bias=True) @@ -79,7 +79,7 @@ def forward( model_comm_group.size() == 1 or batch_size == 1 ), "Only batch size of 1 is supported when model is sharded accross GPUs" - if self.strategy == "shard_sequence": + if self.shard_strategy == "shard_sequence": assert _FLASH_ATTENTION_AVAILABLE, "Flash attention is required for shard_sequence strategy" assert ( shapes[-1][0] // 2 >= self.window_size[0] @@ -123,7 +123,7 @@ def forward( for t in (query, key, value) ) - if self.strategy == "shard_heads": + if self.shard_strategy == "shard_heads": query = shard_heads(query, shapes=shapes, mgroup=model_comm_group) key = shard_heads(key, shapes=shapes, mgroup=model_comm_group) value = shard_heads(value, shapes=shapes, mgroup=model_comm_group) @@ -145,9 +145,9 @@ def forward( dropout_p=dropout_p, ) # expects (batch heads grid variable) format - if self.strategy == "shard_sequence": + if self.shard_strategy == "shard_sequence": out = out[:, :, halo_size_left : out.shape[-2] - halo_size_right, :] # remove halos - if self.strategy == "shard_heads": + if self.shard_strategy == "shard_heads": out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group) out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)") diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py index 60446d6c..90f9979d 100644 --- a/src/anemoi/models/layers/block.py +++ b/src/anemoi/models/layers/block.py @@ -69,6 +69,7 @@ def __init__( activation: str, window_size: int, dropout_p: float = 0.0, + shard_strategy: str = "shard_heads", ): super().__init__() @@ -87,6 +88,7 @@ def __init__( bias=False, is_causal=False, dropout_p=dropout_p, + shard_strategy=shard_strategy, ) self.mlp = nn.Sequential( diff --git a/src/anemoi/models/layers/chunk.py b/src/anemoi/models/layers/chunk.py index 5c4fae38..494fb614 100644 --- a/src/anemoi/models/layers/chunk.py +++ b/src/anemoi/models/layers/chunk.py @@ -75,6 +75,7 @@ def __init__( mlp_hidden_ratio: int = 4, activation: str = "GELU", dropout_p: float = 0.0, + shard_strategy: str = "shard_heads", ) -> None: """Initialize TransformerProcessor. @@ -103,6 +104,7 @@ def __init__( activation=activation, window_size=window_size, dropout_p=dropout_p, + shard_strategy=shard_strategy, ) def forward( diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py index 4fd32311..6d729c7b 100644 --- a/src/anemoi/models/layers/processor.py +++ b/src/anemoi/models/layers/processor.py @@ -97,6 +97,7 @@ def __init__( num_heads: int = 16, mlp_hidden_ratio: int = 4, dropout_p: float = 0.1, + shard_strategy: str = "shard_heads", **kwargs, ) -> None: """Initialize TransformerProcessor. @@ -138,6 +139,7 @@ def __init__( window_size=window_size, activation=activation, dropout_p=dropout_p, + shard_strategy=shard_strategy, ) self.offload_layers(cpu_offload) From ec0ac2d20ad78afb9d18eced02ebf83ad4d6f674 Mon Sep 17 00:00:00 2001 From: japols Date: Wed, 11 Dec 2024 14:26:06 +0000 Subject: [PATCH 3/5] feat: configurable halo comm strategies (for benchmarking) --- src/anemoi/models/distributed/transformer.py | 117 ++++++++++++++----- src/anemoi/models/layers/attention.py | 7 +- 2 files changed, 88 insertions(+), 36 deletions(-) diff --git a/src/anemoi/models/distributed/transformer.py b/src/anemoi/models/distributed/transformer.py index aadd7e3f..2982dbbb 100644 --- a/src/anemoi/models/distributed/transformer.py +++ b/src/anemoi/models/distributed/transformer.py @@ -9,6 +9,7 @@ import logging +import os from typing import Optional import torch @@ -109,43 +110,95 @@ def _halo_exchange(input_: Tensor, halo_size: int, mgroup: ProcessGroup, bwd: bo """ end = input_.shape[-2] - left_halo = input_[:, :halo_size, :] - right_halo = input_[:, end - halo_size :, :] + left_halo_slice = slice(0, halo_size) + right_halo_slice = slice(end - halo_size, end) + left_send_slice = slice(halo_size, 2 * halo_size) + right_send_slice = slice(end - 2 * halo_size, end - halo_size) - left_send = input_[:, halo_size : 2 * halo_size, :] - right_send = input_[:, end - 2 * halo_size : end - halo_size, :] + if bwd: # reverse halo exchange direction + left_halo_slice, left_send_slice = left_send_slice, left_halo_slice + right_halo_slice, right_send_slice = right_send_slice, right_halo_slice - if bwd: # reverse halo exchange - left_halo, left_send = left_send, left_halo - right_halo, right_send = right_send, right_halo + left_send = input_[:, left_send_slice, :] + right_send = input_[:, right_send_slice, :] + left_halo = torch.empty_like(right_send, device=input_.device) + right_halo = torch.empty_like(left_send, device=input_.device) - my_rank = dist.get_rank(mgroup) + global_rank = dist.get_rank() + local_rank = dist.get_rank(mgroup) group_size = dist.get_world_size(mgroup) - left_rank = dist.get_rank(mgroup) - 1 - right_rank = dist.get_rank(mgroup) + 1 - - if my_rank % 2 != 0: - # send left (can't be rank 0) - dist.send(left_send, left_rank, group=mgroup) - # receive left - dist.recv(left_halo, left_rank, group=mgroup) - - if my_rank != group_size - 1: - # send right - dist.send(right_send, right_rank, group=mgroup) - # receive right - dist.recv(right_halo, right_rank, group=mgroup) + left_rank = global_rank - 1 if local_rank > 0 else None + right_rank = global_rank + 1 if local_rank < group_size - 1 else None + + match os.environ.get("HALO_COMM", "SENDRECV"): + case "SENDRECV": + if local_rank % 2 != 0: + if left_rank is not None: + dist.send(left_send, left_rank, group=mgroup) + dist.recv(left_halo, left_rank, group=mgroup) + if right_rank is not None: + dist.send(right_send, right_rank, group=mgroup) + dist.recv(right_halo, right_rank, group=mgroup) + else: + if right_rank is not None: + dist.recv(right_halo, right_rank, group=mgroup) + dist.send(right_send, right_rank, group=mgroup) + if left_rank is not None: + dist.recv(left_halo, left_rank, group=mgroup) + dist.send(left_send, left_rank, group=mgroup) + case "ISENDRECV": + reqs = [] + if local_rank % 2 != 0: + if left_rank is not None: + reqs.append(dist.isend(left_send, left_rank, group=mgroup)) + reqs.append(dist.irecv(left_halo, left_rank, group=mgroup)) + if right_rank is not None: + reqs.append(dist.isend(right_send, right_rank, group=mgroup)) + reqs.append(dist.irecv(right_halo, right_rank, group=mgroup)) + else: + if right_rank is not None: + reqs.append(dist.irecv(right_halo, right_rank, group=mgroup)) + reqs.append(dist.isend(right_send, right_rank, group=mgroup)) + if left_rank is not None: + reqs.append(dist.irecv(left_halo, left_rank, group=mgroup)) + reqs.append(dist.isend(left_send, left_rank, group=mgroup)) + for req in reqs: + req.wait() + case "ALLGATHER": + combined_send = torch.cat([left_send, right_send], dim=1).contiguous() + halos = [torch.empty_like(combined_send) for _ in range(group_size)] + dist.all_gather(halos, combined_send, group=mgroup) + left_halo = halos[local_rank - 1][:, halo_size:, :] if local_rank > 0 else None + right_halo = halos[local_rank + 1][:, :halo_size, :] if local_rank < group_size - 1 else None + case "ALLTOALL": + input_list = [torch.empty(1, device=input_.device) for _ in range(group_size)] + if left_rank is not None: + input_list[left_rank] = left_send + if right_rank is not None: + input_list[right_rank] = right_send + output_list = [torch.empty_like(input_i, device=input_.device) for input_i in input_list] + dist.all_to_all(output_list, input_list, group=mgroup) + + if left_rank is not None: + left_halo = output_list[left_rank] + if right_rank is not None: + right_halo = output_list[right_rank] + case _: + raise ValueError(f"Unknown halo communication strategy {os.environ['HALO_COMM']}") + + if bwd: + # remove gradient contribution from send regions and add halo regions + if left_rank is not None: + input_[:, left_send_slice, :] = 0 + input_[:, left_halo_slice, :] += left_halo + if right_rank is not None: + input_[:, right_send_slice, :] = 0 + input_[:, right_halo_slice, :] += right_halo else: - if my_rank != group_size - 1: - # receive right - dist.recv(right_halo, right_rank, group=mgroup) - # send right - dist.send(right_send, right_rank, group=mgroup) - if my_rank != 0: - # receive left - dist.recv(left_halo, left_rank, group=mgroup) - # send left - dist.send(left_send, left_rank, group=mgroup) + if left_rank is not None: + input_[:, left_halo_slice, :] = left_halo + if right_rank is not None: + input_[:, right_halo_slice, :] = right_halo return input_ diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py index 6c9f91e3..fec41260 100644 --- a/src/anemoi/models/layers/attention.py +++ b/src/anemoi/models/layers/attention.py @@ -109,8 +109,7 @@ def forward( ) for t in (query, key, value) ) - - else: + else: # shard_heads query, key, value = self.lin_qkv(x).chunk(3, -1) query, key, value = ( @@ -123,7 +122,6 @@ def forward( for t in (query, key, value) ) - if self.shard_strategy == "shard_heads": query = shard_heads(query, shapes=shapes, mgroup=model_comm_group) key = shard_heads(key, shapes=shapes, mgroup=model_comm_group) value = shard_heads(value, shapes=shapes, mgroup=model_comm_group) @@ -147,8 +145,9 @@ def forward( if self.shard_strategy == "shard_sequence": out = out[:, :, halo_size_left : out.shape[-2] - halo_size_right, :] # remove halos - if self.shard_strategy == "shard_heads": + else: # shard_heads out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group) + out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)") out = self.projection(out) From 0aaf0f9d0aafe504ec9a6f95b6625b8760c75cd2 Mon Sep 17 00:00:00 2001 From: japols Date: Tue, 17 Dec 2024 18:27:35 +0000 Subject: [PATCH 4/5] cleanup, use all_to_all for halo exchange --- src/anemoi/models/distributed/transformer.py | 115 ++++++------------- src/anemoi/models/layers/attention.py | 7 +- 2 files changed, 39 insertions(+), 83 deletions(-) diff --git a/src/anemoi/models/distributed/transformer.py b/src/anemoi/models/distributed/transformer.py index 2982dbbb..f9310768 100644 --- a/src/anemoi/models/distributed/transformer.py +++ b/src/anemoi/models/distributed/transformer.py @@ -9,7 +9,6 @@ import logging -import os from typing import Optional import torch @@ -87,10 +86,9 @@ def _seqalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = N def _halo_exchange(input_: Tensor, halo_size: int, mgroup: ProcessGroup, bwd: bool = False) -> Tensor: - """Exchange halo regions between ranks. + """Exchange halo regions between neighboring ranks. - Send halo regions to the left and right and receive from the right and left and extend the input tensor. - Expected format is (batch_size, sequence_length, channels). + Expected format is (batch_size, halo_size + sequence_length + halo_size, channels). Parameters ---------- @@ -106,7 +104,7 @@ def _halo_exchange(input_: Tensor, halo_size: int, mgroup: ProcessGroup, bwd: bo Returns ------- Tensor - Extended input tensor + Tensor with halo regions from neighboring ranks """ end = input_.shape[-2] @@ -115,90 +113,40 @@ def _halo_exchange(input_: Tensor, halo_size: int, mgroup: ProcessGroup, bwd: bo left_send_slice = slice(halo_size, 2 * halo_size) right_send_slice = slice(end - 2 * halo_size, end - halo_size) - if bwd: # reverse halo exchange direction + if bwd: # reverse halo exchange direction for gradient accumulation left_halo_slice, left_send_slice = left_send_slice, left_halo_slice right_halo_slice, right_send_slice = right_send_slice, right_halo_slice left_send = input_[:, left_send_slice, :] right_send = input_[:, right_send_slice, :] - left_halo = torch.empty_like(right_send, device=input_.device) - right_halo = torch.empty_like(left_send, device=input_.device) - global_rank = dist.get_rank() - local_rank = dist.get_rank(mgroup) + # setup neighbor ranks and tensor lists for all_to_all communication + group_rank = dist.get_rank(mgroup) group_size = dist.get_world_size(mgroup) - left_rank = global_rank - 1 if local_rank > 0 else None - right_rank = global_rank + 1 if local_rank < group_size - 1 else None - - match os.environ.get("HALO_COMM", "SENDRECV"): - case "SENDRECV": - if local_rank % 2 != 0: - if left_rank is not None: - dist.send(left_send, left_rank, group=mgroup) - dist.recv(left_halo, left_rank, group=mgroup) - if right_rank is not None: - dist.send(right_send, right_rank, group=mgroup) - dist.recv(right_halo, right_rank, group=mgroup) - else: - if right_rank is not None: - dist.recv(right_halo, right_rank, group=mgroup) - dist.send(right_send, right_rank, group=mgroup) - if left_rank is not None: - dist.recv(left_halo, left_rank, group=mgroup) - dist.send(left_send, left_rank, group=mgroup) - case "ISENDRECV": - reqs = [] - if local_rank % 2 != 0: - if left_rank is not None: - reqs.append(dist.isend(left_send, left_rank, group=mgroup)) - reqs.append(dist.irecv(left_halo, left_rank, group=mgroup)) - if right_rank is not None: - reqs.append(dist.isend(right_send, right_rank, group=mgroup)) - reqs.append(dist.irecv(right_halo, right_rank, group=mgroup)) - else: - if right_rank is not None: - reqs.append(dist.irecv(right_halo, right_rank, group=mgroup)) - reqs.append(dist.isend(right_send, right_rank, group=mgroup)) - if left_rank is not None: - reqs.append(dist.irecv(left_halo, left_rank, group=mgroup)) - reqs.append(dist.isend(left_send, left_rank, group=mgroup)) - for req in reqs: - req.wait() - case "ALLGATHER": - combined_send = torch.cat([left_send, right_send], dim=1).contiguous() - halos = [torch.empty_like(combined_send) for _ in range(group_size)] - dist.all_gather(halos, combined_send, group=mgroup) - left_halo = halos[local_rank - 1][:, halo_size:, :] if local_rank > 0 else None - right_halo = halos[local_rank + 1][:, :halo_size, :] if local_rank < group_size - 1 else None - case "ALLTOALL": - input_list = [torch.empty(1, device=input_.device) for _ in range(group_size)] - if left_rank is not None: - input_list[left_rank] = left_send - if right_rank is not None: - input_list[right_rank] = right_send - output_list = [torch.empty_like(input_i, device=input_.device) for input_i in input_list] - dist.all_to_all(output_list, input_list, group=mgroup) - - if left_rank is not None: - left_halo = output_list[left_rank] - if right_rank is not None: - right_halo = output_list[right_rank] - case _: - raise ValueError(f"Unknown halo communication strategy {os.environ['HALO_COMM']}") - - if bwd: - # remove gradient contribution from send regions and add halo regions + left_rank = group_rank - 1 if group_rank > 0 else None + right_rank = group_rank + 1 if group_rank < group_size - 1 else None + + input_list = [torch.empty(0, device=input_.device) for _ in range(group_size)] + if left_rank is not None: + input_list[left_rank] = left_send + if right_rank is not None: + input_list[right_rank] = right_send + output_list = [torch.empty_like(input_i, device=input_.device) for input_i in input_list] + + dist.all_to_all(output_list, input_list, group=mgroup) + + if bwd: # add gradient contributions to halo regions and zero out send regions if left_rank is not None: input_[:, left_send_slice, :] = 0 - input_[:, left_halo_slice, :] += left_halo + input_[:, left_halo_slice, :] += output_list[left_rank] if right_rank is not None: input_[:, right_send_slice, :] = 0 - input_[:, right_halo_slice, :] += right_halo - else: + input_[:, right_halo_slice, :] += output_list[right_rank] + else: # add halo regions to input tensor if left_rank is not None: - input_[:, left_halo_slice, :] = left_halo + input_[:, left_halo_slice, :] = output_list[left_rank] if right_rank is not None: - input_[:, right_halo_slice, :] = right_halo + input_[:, right_halo_slice, :] = output_list[right_rank] return input_ @@ -251,7 +199,7 @@ def shard_sequence(input_: Tensor, shapes: list, mgroup: ProcessGroup) -> Tensor return _SplitSequenceParallelSection.apply(input_, shapes, mgroup) -def halo_exchange(x: Tensor, halo_size: int, mgroup: ProcessGroup) -> None: +def halo_exchange(x: Tensor, halo_size: int, mgroup: ProcessGroup) -> Tensor: """Exchange halo regions between ranks, Parameters @@ -262,10 +210,18 @@ def halo_exchange(x: Tensor, halo_size: int, mgroup: ProcessGroup) -> None: Halo size (left, right) mgroup : ProcessGroup Model communication group + + Returns + ------- + Tensor, int, int + Tensor appended with halo regions from neighboring ranks, left halo size, right halo size """ + if mgroup is None or dist.get_world_size(mgroup) == 1: + return x, 0, 0 + # pad tensor with halo regions - halo_size_left = halo_size if (mgroup and dist.get_rank(mgroup) != 0) else 0 - halo_size_right = halo_size if (mgroup and dist.get_rank(mgroup) != dist.get_world_size(mgroup) - 1) else 0 + halo_size_left = halo_size if dist.get_rank(mgroup) != 0 else 0 + halo_size_right = halo_size if dist.get_rank(mgroup) != dist.get_world_size(mgroup) - 1 else 0 x_pad = torch.nn.functional.pad(x, pad=(0, 0, halo_size_left, halo_size_right), mode="constant", value=0) out = _HaloExchange.apply(x_pad, halo_size, mgroup) @@ -333,7 +289,6 @@ def forward(ctx, input_, halo_size_, mgroup_): def backward(ctx, grad_output): if ctx.mgroup: return ( - # not sure if this works yet, need to test _halo_exchange(grad_output, ctx.halo_size, ctx.mgroup, bwd=True), None, None, diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py index fec41260..0c35ad30 100644 --- a/src/anemoi/models/layers/attention.py +++ b/src/anemoi/models/layers/attention.py @@ -69,6 +69,9 @@ def __init__( if shard_strategy not in ["shard_heads", "shard_sequence"]: raise ValueError(f"Invalid shard_strategy: {shard_strategy}") + if shard_strategy == "shard_sequence": # remove this after PR #47 is merged (sliding window support) + assert _FLASH_ATTENTION_AVAILABLE, "Flash attention is required for shard_sequence strategy" + self.projection = nn.Linear(embed_dim, embed_dim, bias=True) def forward( @@ -80,7 +83,6 @@ def forward( ), "Only batch size of 1 is supported when model is sharded accross GPUs" if self.shard_strategy == "shard_sequence": - assert _FLASH_ATTENTION_AVAILABLE, "Flash attention is required for shard_sequence strategy" assert ( shapes[-1][0] // 2 >= self.window_size[0] ), "Sharded sequence length must be at least twice the window size" @@ -97,10 +99,9 @@ def forward( x_bgc, halo_size=self.window_size[0], mgroup=model_comm_group ) - # compute q, k, v (on local sequence shards) + # compute q, k, v (on local sequence shards with halos) query, key, value = self.lin_qkv(x_plus_halos).chunk(3, -1) - # further unpack feature dimension query, key, value = ( einops.rearrange( t, From 943a0d4ed5e935ce365018909484e1aa8b57821a Mon Sep 17 00:00:00 2001 From: japols Date: Tue, 17 Dec 2024 18:31:02 +0000 Subject: [PATCH 5/5] docs: changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e071b770..125679af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ Keep it human-readable, your future self will thank you! - Mask NaN values in training loss function [#271](https://github.com/ecmwf-lab/aifs-mono/issues/271) - New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64) - Contributors file [#69](https://github.com/ecmwf/anemoi-models/pull/69) +- Add sequence sharding strategy for TransformerProcessor [#90](https://github.com/ecmwf/anemoi-models/pull/90) ### Changed - Bugfixes for CI