From d785adab4a62ab2bd91d60e5ae9181bf8cc97a32 Mon Sep 17 00:00:00 2001
From: Tom Vercauteren
Date: Fri, 27 Sep 2024 09:21:36 +0100
Subject: [PATCH] Added simple workarounds for gather_mm and segment_mm (#57)

* Added simple workarounds for gather_mm and segment_mm. See #56

* bumping python and pytorch version in CI

* enabling black on notebooks in CI

* updating github actions to avoid deprecation warning
---
 .github/workflows/black.yml            |   4 +-
 .github/workflows/python-package.yml   |  17 ++-
 setup.py                               |   3 +-
 torchsparsegradutils/__init__.py       |  10 +-
 torchsparsegradutils/indexed_matmul.py | 117 ++++++++++++++++++
 .../tests/test_indexed_matmul.py       |  94 ++++++++++++++
 6 files changed, 237 insertions(+), 8 deletions(-)
 create mode 100644 torchsparsegradutils/indexed_matmul.py
 create mode 100644 torchsparsegradutils/tests/test_indexed_matmul.py

diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml
index 9065b5e..bfc6712 100644
--- a/.github/workflows/black.yml
+++ b/.github/workflows/black.yml
@@ -6,5 +6,7 @@ jobs:
   lint:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - uses: psf/black@stable
+        with:
+          jupyter: true
diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 6164b04..6f0492e 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -12,21 +12,28 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.9", "3.10"]
-        torch-version: ["1.13.1", "2.0.1"]
+        python-version: ["3.8", "3.10", "3.12"]
+        torch-version: ["1.13.1", "2.4.1"]
+        exclude:
+          - python-version: "3.12"
+            torch-version: "1.13.1"
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
           pip install torch==${{ matrix.torch-version }}
-          python -m pip install flake8 black
+          python -m pip install flake8 black[jupyter]
           if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
+      - name: numpy downgrade for pytorch 1.x
+        if: startsWith(matrix.torch-version, '1.')
+        run: |
+          pip install "numpy<2"
       - name: Lint check with flake8
         run: |
           # stop the build if there are Python syntax errors or undefined names
diff --git a/setup.py b/setup.py
index 7459a4d..3257e06 100644
--- a/setup.py
+++ b/setup.py
@@ -18,8 +18,9 @@ def readme():
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.12",
     ],
-    python_requires=">=3.8, <3.11",
+    python_requires=">=3.8",
     keywords="sparse torch utility",
     url="https://github.com/cai4cai/torchsparsegradutils",
     author="CAI4CAI research group",
diff --git a/torchsparsegradutils/__init__.py b/torchsparsegradutils/__init__.py
index f466352..5dd5b34 100644
--- a/torchsparsegradutils/__init__.py
+++ b/torchsparsegradutils/__init__.py
@@ -1,5 +1,13 @@
 from .sparse_matmul import sparse_mm
+from .indexed_matmul import gather_mm, segment_mm
 from .sparse_solve import sparse_triangular_solve, sparse_generic_solve
 from .sparse_lstsq import sparse_generic_lstsq
 
-__all__ = ["sparse_mm", "sparse_triangular_solve", "sparse_generic_solve", "sparse_generic_lstsq"]
+__all__ = [
+    "sparse_mm",
+    "gather_mm",
+    "segment_mm",
+    "sparse_triangular_solve",
+    "sparse_generic_solve",
+    "sparse_generic_lstsq",
+]
diff --git a/torchsparsegradutils/indexed_matmul.py b/torchsparsegradutils/indexed_matmul.py
new file mode 100644
index 0000000..f6dc153
--- /dev/null
+++ b/torchsparsegradutils/indexed_matmul.py
@@ -0,0 +1,117 @@
+import torch
+
+try:
+    import dgl.ops as dglops
+
+    dgl_installed = True
+except ImportError:
+    dgl_installed = False
+
+
+def segment_mm(a, b, seglen_a):
+    """
+    Performs matrix multiplication according to segments.
+    See https://docs.dgl.ai/generated/dgl.ops.segment_mm.html
+
+    Suppose ``seglen_a == [10, 5, 0, 3]``, the operator will perform
+    four matrix multiplications::
+
+        a[0:10] @ b[0], a[10:15] @ b[1],
+        a[15:15] @ b[2], a[15:18] @ b[3]
+
+    Args:
+        a (torch.Tensor): The left operand, 2-D tensor of shape ``(N, D1)``
+        b (torch.Tensor): The right operand, 3-D tensor of shape ``(R, D1, D2)``
+        seglen_a (torch.Tensor): An integer tensor of shape ``(R,)``. Each element is the length of the corresponding segment of ``a``. The sum of all elements must equal ``N``.
+
+    Returns:
+        torch.Tensor: The output dense matrix of shape ``(N, D2)``
+    """
+    if torch.__version__ < (2, 4):
+        raise NotImplementedError("PyTorch version is too old for nested tensors")
+
+    if dgl_installed:
+        # DGL is probably more computationally efficient
+        # See https://github.com/pytorch/pytorch/issues/136747
+        return dglops.segment_mm(a, b, seglen_a)
+
+    if not a.dim() == 2 or not b.dim() == 3 or not seglen_a.dim() == 1:
+        raise ValueError("Input tensors have unexpected dimensions")
+
+    N, _ = a.shape
+    R, D1, D2 = b.shape
+
+    # Sanity check sizes
+    if not a.shape[1] == D1 or not seglen_a.shape[0] == R:
+        raise ValueError("Incompatible size for inputs")
+
+    segidx_a = torch.cumsum(seglen_a[:-1], dim=0).long().cpu()  # tensor_split expects long indices on the CPU
+
+    # Ideally the conversions below to nested tensors would be handled natively
+    nested_a = torch.nested.as_nested_tensor(torch.tensor_split(a, segidx_a, dim=0))
+    nested_b = torch.nested.as_nested_tensor(list(torch.unbind(b, dim=0)))
+
+    # The actual segment matmul computation
+    nested_ab = torch.matmul(nested_a, nested_b)
+
+    # Convert back to a regular tensor; again, ideally this would be handled natively
+    ab = torch.cat(nested_ab.unbind(), dim=0)
+    return ab
+
+
+def gather_mm(a, b, idx_b):
+    """
+    Gather data according to the given indices and perform matrix multiplication.
+    See https://docs.dgl.ai/generated/dgl.ops.gather_mm.html
+
+    Let the result tensor be ``c``; the operator conducts the following computation::
+
+        c[i] = a[i] @ b[idx_b[i]],
+        where len(c) == len(idx_b)
+
+    Args:
+        a (torch.Tensor): A 2-D tensor of shape ``(N, D1)``
+        b (torch.Tensor): A 3-D tensor of shape ``(R, D1, D2)``
+        idx_b (torch.Tensor): A 1-D integer tensor of shape ``(N,)``.
+
+    Returns:
+        torch.Tensor: The output dense matrix of shape ``(N, D2)``
+    """
+    if torch.__version__ < (2, 4):
+        raise NotImplementedError("PyTorch version is too old for nested tensors")
+
+    if dgl_installed:
+        # DGL is more computationally efficient
+        # See https://github.com/pytorch/pytorch/issues/136747
+        return dglops.gather_mm(a, b, idx_b)
+
+    # Dependency-free fallback
+    if not isinstance(a, torch.Tensor) or not isinstance(b, torch.Tensor) or not isinstance(idx_b, torch.Tensor):
+        raise ValueError("Inputs should be instances of torch.Tensor")
+
+    if not a.dim() == 2 or not b.dim() == 3 or not idx_b.dim() == 1:
+        raise ValueError("Input tensors have unexpected dimensions")
+
+    N = idx_b.shape[0]
+    R, D1, D2 = b.shape
+
+    # Sanity check sizes
+    if not a.shape[0] == N or not a.shape[1] == D1:
+        raise ValueError("Incompatible size for inputs")
+
+    torchdevice = a.device
+    src_idx = torch.arange(N, device=torchdevice)
+
+    # Ideally the conversions below to nested tensors would be handled without for loops and without copies
+    nested_a = torch.nested.as_nested_tensor([a[idx_b == i, :] for i in range(R)])
+    src_idx_reshuffled = torch.cat([src_idx[idx_b == i] for i in range(R)])
+    nested_b = torch.nested.as_nested_tensor([b[i, :, :] for i in range(R)])
+
+    # The actual gather matmul computation
+    nested_ab = torch.matmul(nested_a, nested_b)
+
+    # Convert back to a regular tensor; again, ideally this would be handled natively with no copy
+    ab_segmented = torch.cat(nested_ab.unbind(), dim=0)
+    ab = torch.empty((N, D2), dtype=a.dtype, device=torchdevice)
+    ab[src_idx_reshuffled] = ab_segmented
+    return ab
diff --git a/torchsparsegradutils/tests/test_indexed_matmul.py b/torchsparsegradutils/tests/test_indexed_matmul.py
new file mode 100644
index 0000000..7c9a527
--- /dev/null
+++ b/torchsparsegradutils/tests/test_indexed_matmul.py
@@ -0,0 +1,94 @@
+import torch
+import pytest
+
+if torch.__version__ < (2, 4):
+    pytest.skip(
+        "Skipping tests based on nested tensors since an old version of pytorch is used", allow_module_level=True
+    )
+
+from torchsparsegradutils import gather_mm, segment_mm
+
+# Identify Testing Parameters
+DEVICES = [torch.device("cpu")]
+if torch.cuda.is_available():
+    DEVICES.append(torch.device("cuda"))
+
+TEST_DATA = [
+    # name    N,   R, D1, D2
+    ("small", 100, 32, 7, 10),
+]
+
+INDEX_DTYPES = [torch.int32, torch.int64]
+VALUE_DTYPES = [torch.float32, torch.float64]
+
+ATOL = 1e-6  # relaxed tolerance to allow for float32
+RTOL = 1e-4
+
+
+# Define Test Names:
+def data_id(shapes):
+    return shapes[0]
+
+
+def device_id(device):
+    return str(device)
+
+
+def dtype_id(dtype):
+    return str(dtype).split(".")[-1]
+
+
+# Define Fixtures
+
+
+@pytest.fixture(params=TEST_DATA, ids=[data_id(d) for d in TEST_DATA])
+def shapes(request):
+    return request.param
+
+
+@pytest.fixture(params=VALUE_DTYPES, ids=[dtype_id(d) for d in VALUE_DTYPES])
+def value_dtype(request):
+    return request.param
+
+
+@pytest.fixture(params=INDEX_DTYPES, ids=[dtype_id(d) for d in INDEX_DTYPES])
+def index_dtype(request):
+    return request.param
+
+
+@pytest.fixture(params=DEVICES, ids=[device_id(d) for d in DEVICES])
+def device(request):
+    return request.param
+
+
+# Define Tests
+
+
+def test_segment_mm(device, value_dtype, index_dtype, shapes):
+    _, N, R, D1, D2 = shapes
+
+    a = torch.randn((N, D1), dtype=value_dtype, device=device)
+    b = torch.randn((R, D1, D2), dtype=value_dtype, device=device)
+    seglen_a = torch.randint(low=1, high=int(N / R), size=(R,), dtype=index_dtype, device=device)
+    seglen_a[-1] = N - seglen_a[:-1].sum()
+
+    ab = segment_mm(a, b, seglen_a)
+
+    k = 0
+    for i in range(R):
+        for j in range(seglen_a[i]):
+            assert torch.allclose(ab[k, :], a[k, :] @ b[i, :, :], atol=ATOL, rtol=RTOL)
+            k += 1
+
+
+def test_gather_mm(device, value_dtype, index_dtype, shapes):
+    _, N, R, D1, D2 = shapes
+
+    a = torch.randn((N, D1), dtype=value_dtype, device=device)
+    b = torch.randn((R, D1, D2), dtype=value_dtype, device=device)
+    idx_b = torch.randint(low=0, high=R, size=(N,), dtype=index_dtype, device=device)
+
+    ab = gather_mm(a, b, idx_b)
+
+    for i in range(N):
+        assert torch.allclose(ab[i, :], a[i, :] @ b[idx_b[i], :, :], atol=ATOL, rtol=RTOL)
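
For illustration, here is a minimal usage sketch of the new segment_mm. It is not part of the patch itself; the shapes and segment lengths are made up, and it assumes a torchsparsegradutils install with PyTorch >= 2.4 so that the nested-tensor fallback above is usable:

    import torch
    from torchsparsegradutils import segment_mm

    # Split the N=6 rows of `a` into R=3 segments of lengths 3, 1 and 2,
    # multiplying each segment by its own (D1, D2) = (4, 5) matrix from `b`
    a = torch.randn(6, 4)
    b = torch.randn(3, 4, 5)
    seglen_a = torch.tensor([3, 1, 2])  # must sum to N=6

    ab = segment_mm(a, b, seglen_a)  # shape (6, 5)

    # Reference computation with explicit slicing, as in the docstring
    expected = torch.cat([a[0:3] @ b[0], a[3:4] @ b[1], a[4:6] @ b[2]], dim=0)
    assert torch.allclose(ab, expected, atol=1e-6)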
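Likewise, a small sketch for gather_mm under the same assumptions; each entry of idx_b selects which matrix of `b` multiplies the corresponding row of `a`:

    import torch
    from torchsparsegradutils import gather_mm

    a = torch.randn(4, 3)  # N=4 rows
    b = torch.randn(2, 3, 5)  # R=2 candidate matrices
    idx_b = torch.tensor([0, 1, 1, 0])  # row i of `a` is multiplied by b[idx_b[i]]

    c = gather_mm(a, b, idx_b)  # shape (4, 5)

    # Reference computation, one row at a time
    for i in range(len(idx_b)):
        assert torch.allclose(c[i], a[i] @ b[idx_b[i]], atol=1e-6)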