Added simple workarounds for gather_mm and segment_mm (#57)
* Added simple workarounds for gather_mm and segment_mm. See #56

* bumping Python and PyTorch versions in CI

* enabling black on notebooks in CI

* updating GitHub Actions to avoid deprecation warnings
tvercaut authored Sep 27, 2024
1 parent 0f92297 commit d785ada
Showing 6 changed files with 237 additions and 8 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/black.yml
@@ -6,5 +6,7 @@ jobs:
   lint:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - uses: psf/black@stable
+        with:
+          jupyter: true
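(Editor's note, not from the commit itself: the `jupyter: true` input of the stable Black action makes it install Black with Jupyter support, so notebooks are formatted as well; this mirrors the `black[jupyter]` install added to the Python package workflow below.)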
17 changes: 12 additions & 5 deletions .github/workflows/python-package.yml
@@ -12,21 +12,28 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.9", "3.10"]
-        torch-version: ["1.13.1", "2.0.1"]
+        python-version: ["3.8", "3.10", "3.12"]
+        torch-version: ["1.13.1", "2.4.1"]
+        exclude:
+          - python-version: "3.12"
+            torch-version: "1.13.1"
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
           pip install torch==${{ matrix.torch-version }}
-          python -m pip install flake8 black
+          python -m pip install flake8 black[jupyter]
           if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
+      - name: numpy downgrade for pytorch 1.x
+        if: startsWith(matrix.torch-version, '1.')
+        run: |
+          pip install "numpy<2"
       - name: Lint check with flake8
         run: |
           # stop the build if there are Python syntax errors or undefined names
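(Editor's note, not from the commit itself: the new conditional step pins NumPy below 2 for the PyTorch 1.x jobs, presumably because wheels built against NumPy 1.x fail to import under NumPy 2. The `exclude` entry drops the Python 3.12 / PyTorch 1.13.1 combination, for which no wheels exist.)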
3 changes: 2 additions & 1 deletion setup.py
@@ -18,8 +18,9 @@ def readme():
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.12",
     ],
-    python_requires=">=3.8, <3.11",
+    python_requires=">=3.8",
     keywords="sparse torch utility",
     url="https://github.com/cai4cai/torchsparsegradutils",
     author="CAI4CAI research group",
10 changes: 9 additions & 1 deletion torchsparsegradutils/__init__.py
@@ -1,5 +1,13 @@
 from .sparse_matmul import sparse_mm
+from .indexed_matmul import gather_mm, segment_mm
 from .sparse_solve import sparse_triangular_solve, sparse_generic_solve
 from .sparse_lstsq import sparse_generic_lstsq
 
-__all__ = ["sparse_mm", "sparse_triangular_solve", "sparse_generic_solve", "sparse_generic_lstsq"]
+__all__ = [
+    "sparse_mm",
+    "gather_mm",
+    "segment_mm",
+    "sparse_triangular_solve",
+    "sparse_generic_solve",
+    "sparse_generic_lstsq",
+]
117 changes: 117 additions & 0 deletions torchsparsegradutils/indexed_matmul.py
@@ -0,0 +1,117 @@
import torch

try:
    import dgl.ops as dglops

    dgl_installed = True
except ImportError:
    dgl_installed = False


def segment_mm(a, b, seglen_a):
    """
    Performs matrix multiplication according to segments.
    See https://docs.dgl.ai/generated/dgl.ops.segment_mm.html

    If ``seglen_a == [10, 5, 0, 3]``, the operator performs four matrix multiplications::

        a[0:10] @ b[0], a[10:15] @ b[1],
        a[15:15] @ b[2], a[15:18] @ b[3]

    Args:
        a (torch.Tensor): The left operand, a 2-D tensor of shape ``(N, D1)``.
        b (torch.Tensor): The right operand, a 3-D tensor of shape ``(R, D1, D2)``.
        seglen_a (torch.Tensor): An integer tensor of shape ``(R,)``. Each element
            is the length of a segment of ``a``. The elements must sum to ``N``.

    Returns:
        torch.Tensor: The output dense matrix of shape ``(N, D2)``.
    """
    if torch.__version__ < (2, 4):
        raise NotImplementedError("PyTorch version is too old for nested tensors")

    if dgl_installed:
        # DGL is probably more computationally efficient
        # See https://github.com/pytorch/pytorch/issues/136747
        return dglops.segment_mm(a, b, seglen_a)

    if not a.dim() == 2 or not b.dim() == 3 or not seglen_a.dim() == 1:
        raise ValueError("Input tensors have unexpected dimensions")

    N, _ = a.shape
    R, D1, D2 = b.shape

    # Sanity check sizes
    if not a.shape[1] == D1 or not seglen_a.shape[0] == R:
        raise ValueError("Incompatible size for inputs")

    # torch.tensor_split expects the split indices as a 1-D long tensor on the CPU
    segidx_a = torch.cumsum(seglen_a[:-1], dim=0).long().cpu()

    # Ideally the conversions below to nested tensors would be handled natively
    nested_a = torch.nested.as_nested_tensor(list(torch.tensor_split(a, segidx_a, dim=0)))
    nested_b = torch.nested.as_nested_tensor(list(torch.unbind(b, dim=0)))

    # The actual segmented matmul computation
    nested_ab = torch.matmul(nested_a, nested_b)

    # Convert back to a regular tensor; again, ideally this would be handled natively
    ab = torch.cat(nested_ab.unbind(), dim=0)
    return ab


def gather_mm(a, b, idx_b):
    """
    Gathers data according to the given indices and performs matrix multiplication.
    See https://docs.dgl.ai/generated/dgl.ops.gather_mm.html

    Let the result tensor be ``c``; the operator computes::

        c[i] = a[i] @ b[idx_b[i]]

    where ``len(c) == len(idx_b)``.

    Args:
        a (torch.Tensor): A 2-D tensor of shape ``(N, D1)``.
        b (torch.Tensor): A 3-D tensor of shape ``(R, D1, D2)``.
        idx_b (torch.Tensor): A 1-D integer tensor of shape ``(N,)``.

    Returns:
        torch.Tensor: The output dense matrix of shape ``(N, D2)``.
    """
    if torch.__version__ < (2, 4):
        raise NotImplementedError("PyTorch version is too old for nested tensors")

    if dgl_installed:
        # DGL is probably more computationally efficient
        # See https://github.com/pytorch/pytorch/issues/136747
        return dglops.gather_mm(a, b, idx_b)

    # Dependency-free fallback
    if not isinstance(a, torch.Tensor) or not isinstance(b, torch.Tensor) or not isinstance(idx_b, torch.Tensor):
        raise ValueError("Inputs should be instances of torch.Tensor")

    if not a.dim() == 2 or not b.dim() == 3 or not idx_b.dim() == 1:
        raise ValueError("Input tensors have unexpected dimensions")

    N = idx_b.shape[0]
    R, D1, D2 = b.shape

    # Sanity check sizes
    if not a.shape[0] == N or not a.shape[1] == D1:
        raise ValueError("Incompatible size for inputs")

    torchdevice = a.device
    src_idx = torch.arange(N, device=torchdevice)

    # Group the rows of `a` by the matrix of `b` they select; ideally these
    # conversions to nested tensors would be handled without for loops and without copy
    nested_a = torch.nested.as_nested_tensor([a[idx_b == i, :] for i in range(R)])
    src_idx_reshuffled = torch.cat([src_idx[idx_b == i] for i in range(R)])
    nested_b = torch.nested.as_nested_tensor(list(torch.unbind(b, dim=0)))

    # The actual gather matmul computation
    nested_ab = torch.matmul(nested_a, nested_b)

    # Convert back to a regular tensor and undo the grouping; again, ideally
    # this would be handled natively with no copy
    ab_segmented = torch.cat(nested_ab.unbind(), dim=0)
    ab = torch.empty((N, D2), device=torchdevice, dtype=ab_segmented.dtype)
    ab[src_idx_reshuffled] = ab_segmented
    return ab
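(Editor's note: for concreteness, here is a minimal usage sketch, not part of the commit, that checks both workarounds against a naive loop. It assumes PyTorch >= 2.4 as required by the version guards above; all shapes and values are made up for illustration.)

import torch
from torchsparsegradutils import gather_mm, segment_mm

N, R, D1, D2 = 12, 3, 4, 5
a = torch.randn(N, D1)
b = torch.randn(R, D1, D2)

# segment_mm: consecutive row blocks of `a` each multiply one matrix from `b`
seglen_a = torch.tensor([5, 4, 3])  # must sum to N
ab = segment_mm(a, b, seglen_a)
ref = torch.cat([blk @ b[i] for i, blk in enumerate(torch.split(a, seglen_a.tolist()))])
assert torch.allclose(ab, ref, atol=1e-6)

# gather_mm: row i of `a` multiplies b[idx_b[i]]
idx_b = torch.randint(0, R, (N,))
ab2 = gather_mm(a, b, idx_b)
ref2 = torch.stack([a[i] @ b[idx_b[i]] for i in range(N)])
assert torch.allclose(ab2, ref2, atol=1e-6)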
94 changes: 94 additions & 0 deletions torchsparsegradutils/tests/test_indexed_matmul.py
@@ -0,0 +1,94 @@
import torch
import pytest

if torch.__version__ < (2, 4):
    pytest.skip(
        "Skipping tests based on nested tensors since an old version of PyTorch is used", allow_module_level=True
    )

from torchsparsegradutils import gather_mm, segment_mm

# Identify Testing Parameters
DEVICES = [torch.device("cpu")]
if torch.cuda.is_available():
    DEVICES.append(torch.device("cuda"))

TEST_DATA = [
    # name     N    R   D1  D2
    ("small", 100, 32, 7, 10),
]

INDEX_DTYPES = [torch.int32, torch.int64]
VALUE_DTYPES = [torch.float32, torch.float64]

ATOL = 1e-6 # relaxed tolerance to allow for float32
RTOL = 1e-4


# Define Test Names:
def data_id(shapes):
    return shapes[0]


def device_id(device):
    return str(device)


def dtype_id(dtype):
    return str(dtype).split(".")[-1]


# Define Fixtures


@pytest.fixture(params=TEST_DATA, ids=[data_id(d) for d in TEST_DATA])
def shapes(request):
    return request.param


@pytest.fixture(params=VALUE_DTYPES, ids=[dtype_id(d) for d in VALUE_DTYPES])
def value_dtype(request):
    return request.param


@pytest.fixture(params=INDEX_DTYPES, ids=[dtype_id(d) for d in INDEX_DTYPES])
def index_dtype(request):
    return request.param


@pytest.fixture(params=DEVICES, ids=[device_id(d) for d in DEVICES])
def device(request):
    return request.param


# Define Tests


def test_segment_mm(device, value_dtype, index_dtype, shapes):
    _, N, R, D1, D2 = shapes

    a = torch.randn((N, D1), device=device, dtype=value_dtype)
    b = torch.randn((R, D1, D2), device=device, dtype=value_dtype)
    seglen_a = torch.randint(low=1, high=int(N / R), size=(R,), device=device, dtype=index_dtype)
    seglen_a[-1] = N - seglen_a[:-1].sum()

    ab = segment_mm(a, b, seglen_a)

    k = 0
    for i in range(R):
        for j in range(seglen_a[i]):
            assert torch.allclose(ab[k, :].squeeze(), a[k, :].squeeze() @ b[i, :, :].squeeze(), atol=ATOL, rtol=RTOL)
            k += 1


def test_gather_mm(device, value_dtype, index_dtype, shapes):
    _, N, R, D1, D2 = shapes

    a = torch.randn((N, D1), device=device, dtype=value_dtype)
    b = torch.randn((R, D1, D2), device=device, dtype=value_dtype)
    idx_b = torch.randint(low=0, high=R, size=(N,), device=device, dtype=index_dtype)

    ab = gather_mm(a, b, idx_b)

    for i in range(N):
        assert torch.allclose(ab[i, :].squeeze(), a[i, :].squeeze() @ b[idx_b[i], :, :].squeeze(), atol=ATOL, rtol=RTOL)
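(Editor's note, not from the commit itself: the new tests can be run in isolation with `pytest torchsparsegradutils/tests/test_indexed_matmul.py`; on PyTorch older than 2.4 the module-level guard skips them, matching the version checks in the implementation.)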
