-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[triton] Add clustering support and test
PiperOrigin-RevId: 612417957
- Loading branch information
The jax_triton Authors
committed
Mar 4, 2024
1 parent
e308f5d
commit 4a50e1d
Showing
2 changed files
with
90 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright 2024 The jax_triton Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""H100 clustering tests.""" | ||
|
||
import functools | ||
import math | ||
from unittest import mock | ||
|
||
from absl.testing import absltest | ||
from absl.testing import parameterized | ||
import jax | ||
import jax.numpy as jnp | ||
import jax_triton as jt | ||
import triton | ||
import triton.language as tl | ||
|
||
|
||
def _dummy_fn(x): | ||
assert x.size % 4 == 0 | ||
|
||
@triton.jit | ||
def dummy_kernel(x_ptr, o_ptr): | ||
offs = tl.program_id(axis=0) * 4 + tl.arange(0, 4) | ||
tl.store(o_ptr + offs, tl.load(x_ptr + offs)) | ||
|
||
return jt.triton_call(x, kernel=dummy_kernel, out_shape=x, grid=(x.size // 4)) | ||
|
||
|
||
class ClusterTest(parameterized.TestCase): | ||
|
||
@parameterized.parameters(1, 2, 3, 4, 8) | ||
def test_cluster(self, num_ctas): | ||
if 'h100' not in jax.devices()[0].device_kind.lower(): | ||
self.skipTest('Clusters available only on H100s.') | ||
|
||
cluster_dims = [] | ||
original_compile_ttir_to_ptx_fn = jt.triton_lib.compile_ttir_to_ptx_inplace | ||
|
||
def my_compile_ttir_to_ptx(*args, **kwargs): | ||
nonlocal cluster_dims, original_compile_ttir_to_ptx_fn | ||
ret_args = original_compile_ttir_to_ptx_fn(*args, **kwargs) | ||
cluster_dims = ret_args[-1] | ||
return ret_args | ||
|
||
my_triton_call = functools.partial(jt.triton_call, num_ctas=num_ctas) | ||
|
||
with mock.patch.object(jt, 'triton_call', my_triton_call): | ||
with mock.patch.object( | ||
jt.triton_lib, 'compile_ttir_to_ptx_inplace', my_compile_ttir_to_ptx | ||
): | ||
_dummy_fn(jnp.empty((16,))) | ||
self.assertEqual(math.prod(cluster_dims), num_ctas) | ||
|
||
def test_invalid_cluster_size(self): | ||
if 'h100' not in jax.devices()[0].device_kind.lower(): | ||
self.skipTest('Clusters available on H100s.') | ||
|
||
my_triton_call = functools.partial(jt.triton_call, num_ctas=9) | ||
|
||
with mock.patch.object(jt, 'triton_call', my_triton_call): | ||
with self.assertRaises(Exception): | ||
_dummy_fn(jnp.empty((16,))) | ||
|
||
def test_cluster_not_available(self): | ||
if 'h100' in jax.devices()[0].device_kind.lower(): | ||
self.skipTest('Clusters available only on H100s.') | ||
|
||
my_triton_call = functools.partial(jt.triton_call, num_ctas=2) | ||
|
||
with mock.patch.object(jt, 'triton_call', my_triton_call): | ||
with self.assertRaises(ValueError): | ||
_dummy_fn(jnp.empty((16,))) | ||
|
||
|
||
if __name__ == '__main__': | ||
absltest.main() |