diff --git a/models/src/anemoi/models/layers/block.py b/models/src/anemoi/models/layers/block.py
index 72e487d2..0c417d71 100644
--- a/models/src/anemoi/models/layers/block.py
+++ b/models/src/anemoi/models/layers/block.py
@@ -13,6 +13,7 @@
 from abc import ABC
 from abc import abstractmethod
 from typing import Optional
+from typing import Union
 
 import einops
 import torch
@@ -413,6 +414,51 @@ def shard_output_seq(
 
         return out
 
+    def attention_block(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        edges: Tensor,
+        edge_index: Adj,
+        shapes: tuple,
+        batch_size: int,
+        size: Union[int, tuple[int, int]],
+        num_chunks: int,
+        model_comm_group: Optional[ProcessGroup] = None,
+    ) -> Tensor:
+        if model_comm_group is not None:
+            assert (
+                model_comm_group.size() == 1 or batch_size == 1
+            ), "Only batch size of 1 is supported when model is sharded across GPUs"
+
+        query, key, value, edges = self.shard_qkve_heads(query, key, value, edges, shapes, batch_size, model_comm_group)
+
+        conv_size = size if isinstance(size, tuple) else None
+
+        if num_chunks > 1:
+            # split 1-hop edges into chunks, compute self.conv chunk-wise and aggregate
+            edge_attr_list, edge_index_list = sort_edges_1hop_chunks(
+                num_nodes=size, edge_attr=edges, edge_index=edge_index, num_chunks=num_chunks
+            )
+            # shape: (num_nodes, num_heads, out_channels_conv)
+            out = torch.zeros((*query.shape[:2], self.out_channels_conv), device=query.device)
+            for i in range(num_chunks):
+                out += self.conv(
+                    query=query,
+                    key=key,
+                    value=value,
+                    edge_attr=edge_attr_list[i],
+                    edge_index=edge_index_list[i],
+                    size=conv_size,
+                )
+        else:
+            out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=conv_size)
+
+        out = self.shard_output_seq(out, shapes, batch_size, model_comm_group)
+
+        return out
+
     @abstractmethod
     def forward(
         self,
@@ -421,8 +467,8 @@ def forward(
         edge_index: Adj,
         shapes: tuple,
         batch_size: int,
+        size: Union[int, tuple[int, int]],
         model_comm_group: Optional[ProcessGroup] = None,
-        size: Optional[Size] = None,
     ): ...
 
@@ -483,8 +529,8 @@ def forward(
         edge_index: Adj,
         shapes: tuple,
         batch_size: int,
+        size: tuple[int, int],
         model_comm_group: Optional[ProcessGroup] = None,
-        size: Optional[Size] = None,
     ):
         x_skip = x
 
@@ -498,34 +544,20 @@ def forward(
         value = self.lin_value(x[0])
         edges = self.lin_edge(edge_attr)
 
-        if model_comm_group is not None:
-            assert (
-                model_comm_group.size() == 1 or batch_size == 1
-            ), "Only batch size of 1 is supported when model is sharded across GPUs"
-
-        query, key, value, edges = self.shard_qkve_heads(query, key, value, edges, shapes, batch_size, model_comm_group)
-
         num_chunks = self.num_chunks if self.training else NUM_CHUNKS_INFERENCE
-        if num_chunks > 1:
-            # split 1-hop edges into chunks, compute self.conv chunk-wise and aggregate
-            edge_attr_list, edge_index_list = sort_edges_1hop_chunks(
-                num_nodes=size, edge_attr=edges, edge_index=edge_index, num_chunks=num_chunks
-            )
-            out = torch.zeros((x[1].shape[0], self.num_heads, self.out_channels_conv), device=x[1].device)
-            for i in range(num_chunks):
-                out += self.conv(
-                    query=query,
-                    key=key,
-                    value=value,
-                    edge_attr=edge_attr_list[i],
-                    edge_index=edge_index_list[i],
-                    size=size,
-                )
-        else:
-            out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)
-
-        out = self.shard_output_seq(out, shapes, batch_size, model_comm_group)
+
+        out = self.attention_block(
+            query=query,
+            key=key,
+            value=value,
+            edges=edges,
+            edge_index=edge_index,
+            shapes=shapes,
+            batch_size=batch_size,
+            size=size,
+            num_chunks=num_chunks,
+            model_comm_group=model_comm_group,
+        )
 
         # compute out = self.projection(out + x_r) in chunks:
         out = torch.cat([self.projection(chunk) for chunk in torch.tensor_split(out + x_r, num_chunks, dim=0)], dim=0)
@@ -606,8 +638,8 @@ def forward(
         edge_index: Adj,
         shapes: tuple,
         batch_size: int,
+        size: int,
         model_comm_group: Optional[ProcessGroup] = None,
-        size: Optional[Size] = None,
     ):
         x_skip = x
 
@@ -619,17 +651,29 @@ def forward(
 
         edges = self.lin_edge(edge_attr)
 
-        if model_comm_group is not None:
-            assert (
-                model_comm_group.size() == 1 or batch_size == 1
-            ), "Only batch size of 1 is supported when model is sharded across GPUs"
+        num_chunks = self.num_chunks if self.training else NUM_CHUNKS_INFERENCE
 
-        query, key, value, edges = self.shard_qkve_heads(query, key, value, edges, shapes, batch_size, model_comm_group)
-        out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)
-        out = self.shard_output_seq(out, shapes, batch_size, model_comm_group)
-        out = self.projection(out + x_r)
+        out = self.attention_block(
+            query=query,
+            key=key,
+            value=value,
+            edges=edges,
+            edge_index=edge_index,
+            shapes=shapes,
+            batch_size=batch_size,
+            size=size,
+            num_chunks=num_chunks,
+            model_comm_group=model_comm_group,
+        )
+
+        # compute out = self.projection(out + x_r) in chunks:
+        out = torch.cat([self.projection(chunk) for chunk in torch.tensor_split(out + x_r, num_chunks, dim=0)], dim=0)
 
         out = out + x_skip
-        nodes_new = self.node_dst_mlp(out) + out
+        # compute nodes_new = self.node_dst_mlp(out) + out in chunks:
+        nodes_new = torch.cat(
+            [self.node_dst_mlp(chunk) + chunk for chunk in out.tensor_split(num_chunks, dim=0)], dim=0
+        )
 
         return nodes_new, edge_attr
diff --git a/models/src/anemoi/models/layers/chunk.py b/models/src/anemoi/models/layers/chunk.py
index 5c4fae38..a4bb897f 100644
--- a/models/src/anemoi/models/layers/chunk.py
+++ b/models/src/anemoi/models/layers/chunk.py
@@ -233,6 +233,6 @@ def forward(
         size: Optional[Size] = None,
     ) -> OptPairTensor:
         for i in range(self.num_layers):
-            x, edge_attr = self.blocks[i](x, edge_attr, edge_index, shapes, batch_size, model_comm_group, size=size)
+            x, edge_attr = self.blocks[i](x, edge_attr, edge_index, shapes, batch_size, size, model_comm_group)
 
         return x, edge_attr
diff --git a/models/src/anemoi/models/layers/mapper.py b/models/src/anemoi/models/layers/mapper.py
index 1ae45031..bf5839c2 100644
--- a/models/src/anemoi/models/layers/mapper.py
+++ b/models/src/anemoi/models/layers/mapper.py
@@ -258,13 +258,13 @@ def forward(
         x_src, x_dst, shapes_src, shapes_dst = self.pre_process(x, shard_shapes, model_comm_group)
 
         (x_src, x_dst), edge_attr = self.proc(
-            (x_src, x_dst),
-            edge_attr,
-            edge_index,
-            (shapes_src, shapes_dst, shapes_edge_attr),
-            batch_size,
-            model_comm_group,
+            x=(x_src, x_dst),
+            edge_attr=edge_attr,
+            edge_index=edge_index,
+            shapes=(shapes_src, shapes_dst, shapes_edge_attr),
+            batch_size=batch_size,
             size=size,
+            model_comm_group=model_comm_group,
         )
 
         x_dst = self.post_process(x_dst, shapes_dst, model_comm_group)
diff --git a/models/src/anemoi/models/layers/processor.py b/models/src/anemoi/models/layers/processor.py
index 8dba1f66..1e023bda 100644
--- a/models/src/anemoi/models/layers/processor.py
+++ b/models/src/anemoi/models/layers/processor.py
@@ -323,6 +323,7 @@ def forward(
         *args,
         **kwargs,
     ) -> Tensor:
+        size = sum(x[0] for x in shard_shapes)
         shape_nodes = change_channels_in_shape(shard_shapes, self.num_channels)
         edge_attr = self.trainable(self.edge_attr, batch_size)
 
@@ -333,11 +334,12 @@ def forward(
         edge_attr = shard_tensor(edge_attr, 0, shapes_edge_attr, model_comm_group)
 
         x, edge_attr = self.run_layers(
-            (x, edge_attr),
-            edge_index,
-            (shape_nodes, shape_nodes, shapes_edge_attr),
-            batch_size,
-            model_comm_group,
+            data=(x, edge_attr),
+            edge_index=edge_index,
+            shapes=(shape_nodes, shape_nodes, shapes_edge_attr),
+            batch_size=batch_size,
+            size=size,
+            model_comm_group=model_comm_group,
         )
 
         return x
diff --git a/models/tests/layers/block/test_block_graphtransformer.py b/models/tests/layers/block/test_block_graphtransformer.py
index a6162046..d0967562 100644
--- a/models/tests/layers/block/test_block_graphtransformer.py
+++ b/models/tests/layers/block/test_block_graphtransformer.py
@@ -187,6 +187,44 @@ def test_GraphTransformerProcessorBlock_forward_backward(init, block):
         ), f"param.grad.shape ({param.grad.shape}) != param.shape ({param.shape}) for {param}"
 
 
+def test_GraphTransformerProcessorBlock_chunking(init, block, monkeypatch):
+    (
+        in_channels,
+        _hidden_dim,
+        _out_channels,
+        edge_dim,
+        _bias,
+        _activation,
+        _num_heads,
+        _num_chunks,
+    ) = init
+    # Initialize GraphTransformerProcessorBlock
+    block = block
+
+    # Generate random input tensors
+    x = torch.randn((10, in_channels))
+    edge_attr = torch.randn((10, edge_dim))
+    edge_index = torch.randint(1, 10, (2, 10))
+    shapes = (10, 10, 10)
+    batch_size = 1
+    num_chunks = torch.randint(2, 10, (1,)).item()
+
+    # manually set to non-training mode
+    block.eval()
+
+    # result with chunks (size = number of nodes in x)
+    monkeypatch.setenv("ANEMOI_INFERENCE_NUM_CHUNKS", str(num_chunks))
+    importlib.reload(anemoi.models.layers.block)
+    out_chunked, _ = block(x, edge_attr, edge_index, shapes, batch_size, size=10)
+    # result without chunks, reload block for new env variable
+    monkeypatch.setenv("ANEMOI_INFERENCE_NUM_CHUNKS", "1")
+    importlib.reload(anemoi.models.layers.block)
+    out, _ = block(x, edge_attr, edge_index, shapes, batch_size, size=10)
+
+    assert out.shape == out_chunked.shape, f"out.shape ({out.shape}) != out_chunked.shape ({out_chunked.shape})"
+    assert torch.allclose(out, out_chunked, atol=1e-4), "out != out_chunked"
+
+
 @pytest.fixture
 def mapper_block(init):
     (
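
Note on the chunked evaluation pattern used throughout this diff: computing `self.projection(out + x_r)` or `self.node_dst_mlp(out) + out` chunk-by-chunk along dim 0 and concatenating is mathematically identical to the single call, because these modules act independently per row; only peak activation memory changes. A minimal self-contained sketch of that equivalence (`proj` and `x` are illustrative stand-ins, not names from this diff):

    import torch
    import torch.nn as nn

    proj = nn.Linear(64, 64)   # stand-in for a row-wise module such as self.projection
    x = torch.randn(1000, 64)

    out_full = proj(x)
    # split along dim 0, apply per chunk, re-concatenate
    out_chunked = torch.cat([proj(chunk) for chunk in torch.tensor_split(x, 4, dim=0)], dim=0)

    # identical up to floating-point noise; only peak memory differs
    assert torch.allclose(out_full, out_chunked, atol=1e-6)

This is also why the test accepts `atol=1e-4` rather than exact equality: the chunked attention accumulates partial message-passing results, so small floating-point differences are expected.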
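A second note, on the test's environment-variable handling: it toggles chunked inference through `ANEMOI_INFERENCE_NUM_CHUNKS` and reloads `anemoi.models.layers.block` after each change (which presumes `import importlib` and `import anemoi.models.layers.block` exist at the top of the test file). That reload dance only makes sense if `NUM_CHUNKS_INFERENCE` is evaluated once at module import. A sketch of the assumed mechanism, not the module's verbatim code:

    import os

    # Read once when anemoi.models.layers.block is imported; later changes to the
    # environment variable take effect only after importlib.reload() of the module.
    NUM_CHUNKS_INFERENCE = int(os.environ.get("ANEMOI_INFERENCE_NUM_CHUNKS", "1"))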