From 24de7620cf31aae38a181bc85922c22ea3e855c7 Mon Sep 17 00:00:00 2001 From: Tom Vercauteren Date: Fri, 27 Sep 2024 12:04:41 +0100 Subject: [PATCH] black --- torchsparsegradutils/indexed_matmul.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchsparsegradutils/indexed_matmul.py b/torchsparsegradutils/indexed_matmul.py index f9d54e6..ef0c8f7 100644 --- a/torchsparsegradutils/indexed_matmul.py +++ b/torchsparsegradutils/indexed_matmul.py @@ -49,7 +49,7 @@ def segment_mm(a, b, seglen_a): # Ideally the conversions below to nested tensor would be handled natively nested_a = torch.nested.as_nested_tensor(torch.tensor_split(a, segidx_a, dim=0)) - nested_b = torch.nested.as_nested_tensor(torch.split(b, 1, dim=0)).reshape((R,D1,D2)) + nested_b = torch.nested.as_nested_tensor(torch.split(b, 1, dim=0)).reshape((R, D1, D2)) # The actual gather matmul computation nested_ab = torch.matmul(nested_a, nested_b) @@ -105,7 +105,7 @@ def gather_mm(a, b, idx_b): # Ideally the conversions below to nested tensor would be handled without for looops and without copy nested_a = torch.nested.as_nested_tensor([a[idx_b == i, :] for i in range(R)]) src_idx_reshuffled = torch.cat([src_idx[idx_b == i] for i in range(R)]) - nested_b = torch.nested.as_nested_tensor(torch.split(b, 1, dim=0)).reshape((R,D1,D2)) + nested_b = torch.nested.as_nested_tensor(torch.split(b, 1, dim=0)).reshape((R, D1, D2)) # The actual gather matmul computation nested_ab = torch.matmul(nested_a, nested_b)