Skip to content

Commit

Permalink
minor cleanup to avoid one explicit list comprehension in indexed matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
tvercaut committed Sep 27, 2024
1 parent d785ada commit 2d1b38a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchsparsegradutils/indexed_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def segment_mm(a, b, seglen_a):

# Ideally the conversions below to nested tensor would be handled natively
nested_a = torch.nested.as_nested_tensor(torch.tensor_split(a, segidx_a, dim=0))
nested_b = torch.nested.as_nested_tensor(list(map(torch.squeeze, torch.split(b, 1, dim=0))))
nested_b = torch.nested.as_nested_tensor(torch.split(b, 1, dim=0)).reshape((R,D1,D2))

# The actual gather matmul computation
nested_ab = torch.matmul(nested_a, nested_b)
Expand Down Expand Up @@ -105,7 +105,7 @@ def gather_mm(a, b, idx_b):
# Ideally the conversions below to nested tensor would be handled without for looops and without copy
nested_a = torch.nested.as_nested_tensor([a[idx_b == i, :] for i in range(R)])
src_idx_reshuffled = torch.cat([src_idx[idx_b == i] for i in range(R)])
nested_b = torch.nested.as_nested_tensor([b[i, :, :].squeeze() for i in range(R)])
nested_b = torch.nested.as_nested_tensor(torch.split(b, 1, dim=0)).reshape((R,D1,D2))

# The actual gather matmul computation
nested_ab = torch.matmul(nested_a, nested_b)
Expand Down

0 comments on commit 2d1b38a

Please sign in to comment.