Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Models: GraphtransformerProcessor chunking #66

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 81 additions & 37 deletions models/src/anemoi/models/layers/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from abc import ABC
from abc import abstractmethod
from typing import Optional
from typing import Union

import einops
import torch
Expand Down Expand Up @@ -413,6 +414,51 @@ def shard_output_seq(

return out

def attention_block(
self,
query: Tensor,
key: Tensor,
value: Tensor,
edges: Tensor,
edge_index: Adj,
shapes: tuple,
batch_size: int,
size: Union[int, tuple[int, int]],
num_chunks: int,
model_comm_group: Optional[ProcessGroup] = None,
) -> Tensor:
if model_comm_group is not None:
assert (
model_comm_group.size() == 1 or batch_size == 1
), "Only batch size of 1 is supported when model is sharded across GPUs"

query, key, value, edges = self.shard_qkve_heads(query, key, value, edges, shapes, batch_size, model_comm_group)

conv_size = size if isinstance(size, tuple) else None

if num_chunks > 1:
# split 1-hop edges into chunks, compute self.conv chunk-wise and aggregate
edge_attr_list, edge_index_list = sort_edges_1hop_chunks(
num_nodes=size, edge_attr=edges, edge_index=edge_index, num_chunks=num_chunks
)
# shape: (num_nodes, num_heads, out_channels_conv)
out = torch.zeros((*query.shape[:2], self.out_channels_conv), device=query.device)
for i in range(num_chunks):
out += self.conv(
query=query,
key=key,
value=value,
edge_attr=edge_attr_list[i],
edge_index=edge_index_list[i],
size=conv_size,
)
else:
out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=conv_size)

out = self.shard_output_seq(out, shapes, batch_size, model_comm_group)

return out

@abstractmethod
def forward(
self,
Expand All @@ -421,8 +467,8 @@ def forward(
edge_index: Adj,
shapes: tuple,
batch_size: int,
size: Union[int, tuple[int, int]],
model_comm_group: Optional[ProcessGroup] = None,
size: Optional[Size] = None,
): ...


Expand Down Expand Up @@ -483,8 +529,8 @@ def forward(
edge_index: Adj,
shapes: tuple,
batch_size: int,
size: tuple[int, int],
model_comm_group: Optional[ProcessGroup] = None,
size: Optional[Size] = None,
):
x_skip = x

Expand All @@ -498,34 +544,20 @@ def forward(
value = self.lin_value(x[0])
edges = self.lin_edge(edge_attr)

if model_comm_group is not None:
assert (
model_comm_group.size() == 1 or batch_size == 1
), "Only batch size of 1 is supported when model is sharded across GPUs"

query, key, value, edges = self.shard_qkve_heads(query, key, value, edges, shapes, batch_size, model_comm_group)

num_chunks = self.num_chunks if self.training else NUM_CHUNKS_INFERENCE

if num_chunks > 1:
# split 1-hop edges into chunks, compute self.conv chunk-wise and aggregate
edge_attr_list, edge_index_list = sort_edges_1hop_chunks(
num_nodes=size, edge_attr=edges, edge_index=edge_index, num_chunks=num_chunks
)
out = torch.zeros((x[1].shape[0], self.num_heads, self.out_channels_conv), device=x[1].device)
for i in range(num_chunks):
out += self.conv(
query=query,
key=key,
value=value,
edge_attr=edge_attr_list[i],
edge_index=edge_index_list[i],
size=size,
)
else:
out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)

out = self.shard_output_seq(out, shapes, batch_size, model_comm_group)
out = self.attention_block(
query=query,
key=key,
value=value,
edges=edges,
edge_index=edge_index,
shapes=shapes,
batch_size=batch_size,
size=size,
num_chunks=num_chunks,
model_comm_group=model_comm_group,
)

# compute out = self.projection(out + x_r) in chunks:
out = torch.cat([self.projection(chunk) for chunk in torch.tensor_split(out + x_r, num_chunks, dim=0)], dim=0)
Expand Down Expand Up @@ -606,8 +638,8 @@ def forward(
edge_index: Adj,
shapes: tuple,
batch_size: int,
size: int,
model_comm_group: Optional[ProcessGroup] = None,
size: Optional[Size] = None,
):
x_skip = x

Expand All @@ -619,17 +651,29 @@ def forward(

edges = self.lin_edge(edge_attr)

if model_comm_group is not None:
assert (
model_comm_group.size() == 1 or batch_size == 1
), "Only batch size of 1 is supported when model is sharded across GPUs"
num_chunks = self.num_chunks if self.training else NUM_CHUNKS_INFERENCE

query, key, value, edges = self.shard_qkve_heads(query, key, value, edges, shapes, batch_size, model_comm_group)
out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)
out = self.shard_output_seq(out, shapes, batch_size, model_comm_group)
out = self.projection(out + x_r)
out = self.attention_block(
query=query,
key=key,
value=value,
edges=edges,
edge_index=edge_index,
shapes=shapes,
batch_size=batch_size,
size=size,
num_chunks=num_chunks,
model_comm_group=model_comm_group,
)

# compute out = self.projection(out + x_r) in chunks:
out = torch.cat([self.projection(chunk) for chunk in torch.tensor_split(out + x_r, num_chunks, dim=0)], dim=0)

out = out + x_skip
nodes_new = self.node_dst_mlp(out) + out
# compute nodes_new = self.node_dst_mlp(out) + out in chunks:
nodes_new = torch.cat(
[self.node_dst_mlp(chunk) + chunk for chunk in out.tensor_split(num_chunks, dim=0)], dim=0
)

return nodes_new, edge_attr
2 changes: 1 addition & 1 deletion models/src/anemoi/models/layers/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,6 @@ def forward(
size: Optional[Size] = None,
) -> OptPairTensor:
for i in range(self.num_layers):
x, edge_attr = self.blocks[i](x, edge_attr, edge_index, shapes, batch_size, model_comm_group, size=size)
x, edge_attr = self.blocks[i](x, edge_attr, edge_index, shapes, batch_size, size, model_comm_group)

return x, edge_attr
12 changes: 6 additions & 6 deletions models/src/anemoi/models/layers/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,13 @@ def forward(
x_src, x_dst, shapes_src, shapes_dst = self.pre_process(x, shard_shapes, model_comm_group)

(x_src, x_dst), edge_attr = self.proc(
(x_src, x_dst),
edge_attr,
edge_index,
(shapes_src, shapes_dst, shapes_edge_attr),
batch_size,
model_comm_group,
x=(x_src, x_dst),
edge_attr=edge_attr,
edge_index=edge_index,
shapes=(shapes_src, shapes_dst, shapes_edge_attr),
batch_size=batch_size,
size=size,
model_comm_group=model_comm_group,
)

x_dst = self.post_process(x_dst, shapes_dst, model_comm_group)
Expand Down
12 changes: 7 additions & 5 deletions models/src/anemoi/models/layers/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ def forward(
*args,
**kwargs,
) -> Tensor:
size = sum(x[0] for x in shard_shapes)

shape_nodes = change_channels_in_shape(shard_shapes, self.num_channels)
edge_attr = self.trainable(self.edge_attr, batch_size)
Expand All @@ -333,11 +334,12 @@ def forward(
edge_attr = shard_tensor(edge_attr, 0, shapes_edge_attr, model_comm_group)

x, edge_attr = self.run_layers(
(x, edge_attr),
edge_index,
(shape_nodes, shape_nodes, shapes_edge_attr),
batch_size,
model_comm_group,
data=(x, edge_attr),
edge_index=edge_index,
shapes=(shape_nodes, shape_nodes, shapes_edge_attr),
batch_size=batch_size,
size=size,
model_comm_group=model_comm_group,
)

return x
39 changes: 39 additions & 0 deletions models/tests/layers/block/test_block_graphtransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,45 @@ def test_GraphTransformerProcessorBlock_forward_backward(init, block):
), f"param.grad.shape ({param.grad.shape}) != param.shape ({param.shape}) for {param}"


@pytest.fixture
def test_GraphTransformerProcessorBlock_chunking(init, block, monkeypatch):
(
in_channels,
_hidden_dim,
_out_channels,
edge_dim,
_bias,
_activation,
_num_heads,
_num_chunks,
) = init
# Initialize GraphTransformerProcessorBlock
block = block

# Generate random input tensor
x = torch.randn((10, in_channels))
edge_attr = torch.randn((10, edge_dim))
edge_index = torch.randint(1, 10, (2, 10))
shapes = (10, 10, 10)
batch_size = 1
num_chunks = torch.randint(2, 10, (1,)).item()

# manually set to non-training mode
block.eval()

# result with chunks
monkeypatch.setenv("ANEMOI_INFERENCE_NUM_CHUNKS", str(num_chunks))
importlib.reload(anemoi.models.layers.block)
out_chunked, _ = block(x, edge_attr, edge_index, shapes, batch_size)
# result without chunks, reload block for new env variable
monkeypatch.setenv("ANEMOI_INFERENCE_NUM_CHUNKS", "1")
importlib.reload(anemoi.models.layers.block)
out, _ = block(x, edge_attr, edge_index, shapes, batch_size)

assert out.shape == out_chunked.shape, f"out.shape ({out.shape}) != out_chunked.shape ({out_chunked.shape})"
assert torch.allclose(out, out_chunked, atol=1e-4), "out != out_chunked"


@pytest.fixture
def mapper_block(init):
(
Expand Down
Loading