Skip to content

Commit

Permalink
Inference UTs check for trition support from accelerator
Browse files Browse the repository at this point in the history
  • Loading branch information
raza-sikander committed Nov 25, 2024
1 parent f57b1ef commit 984d3c7
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 16 deletions.
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def ref_torch_attention(q, k, v, mask, sm_scale):
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("use_flash", [True, False])
def test_attention(BATCH, H, N_CTX, D_HEAD, causal, use_flash, dtype=torch.float16):
if not deepspeed.HAS_TRITON:
pytest.skip("triton has to be installed for the test")
if not deepspeed.get_accelerator().is_triton_supported():
pytest.skip("triton is not supported on this system")

minus_inf = -65504.0
dev = deepspeed.accelerator.get_accelerator().device_name()
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def test_gelu(batch, sequence, channels, dtype, use_triton_ops):
activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=device)
activations_ref = activations_ds.clone().detach()

if not deepspeed.HAS_TRITON and use_triton_ops:
pytest.skip("triton has to be installed for the test")
if not deepspeed.get_accelerator().is_triton_supported():
pytest.skip("triton is not supported on this system")
ds_out = run_gelu_ds(activations_ds, use_triton_ops)
ref_out = run_gelu_reference(activations_ref)
assert (allclose(ds_out, ref_out))
12 changes: 6 additions & 6 deletions tests/unit/ops/transformer/inference/test_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def ds_triton_implementation(vals, gamma, beta, epsilon):
@pytest.mark.parametrize("dtype", get_dtypes())
@pytest.mark.parametrize("use_triton_ops", [False, True])
def test_layer_norm(batch, seq_len, channels, dtype, use_triton_ops):
if not deepspeed.HAS_TRITON and use_triton_ops:
pytest.skip("triton has to be installed for the test")
if not deepspeed.get_accelerator().is_triton_supported():
pytest.skip("triton is not supported on this system")

vals = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name())
gamma = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name())
Expand Down Expand Up @@ -93,8 +93,8 @@ def residual_ds_triton_implementation(vals, bias, res, gamma, beta, epsilon):
@pytest.mark.parametrize("dtype", get_dtypes())
@pytest.mark.parametrize("use_triton_ops", [False, True])
def test_layer_norm_residual(batch, seq_len, channels, dtype, use_triton_ops):
if not deepspeed.HAS_TRITON and use_triton_ops:
pytest.skip("triton has to be installed for the test")
if not deepspeed.get_accelerator().is_triton_supported():
pytest.skip("triton is not supported on this system")

vals = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name())
residual = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name())
Expand Down Expand Up @@ -163,8 +163,8 @@ def test_layer_norm_residual_store_pre_ln_res(batch, seq_len, channels, dtype):
@pytest.mark.parametrize("residual", [True, False])
@pytest.mark.parametrize("input_bias", [True, False])
def test_triton_layer_norm(M, N, dtype, residual, input_bias, eps=1e-5, device='cuda'):
if not deepspeed.HAS_TRITON:
pytest.skip("triton has to be installed for the test")
if not deepspeed.get_accelerator().is_triton_supported():
pytest.skip("triton is not supported on this system")
dev = get_accelerator().device_name()
torch.manual_seed(0)
# create data
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def run_matmul_ds(a, b, use_triton_ops=False):
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("use_triton_ops", [True])
def test_matmul_4d(B, H, M, K, N, dtype, use_triton_ops):
if not deepspeed.HAS_TRITON and use_triton_ops:
pytest.skip("triton has to be installed for the test")
if not deepspeed.get_accelerator().is_triton_supported():
pytest.skip("triton is not supported on this system")

# skip autotune in testing
from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_residual_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def run_residual_add_reference(hidden_state, residual, attn_output, attn_bias, f
@pytest.mark.parametrize("use_triton_ops", [True, False])
def test_residual_add(batch, sequence, hidden_dim, dtype, mlp_after_attn, add_bias, mp_size, pre_attn_norm,
use_triton_ops):
if not deepspeed.HAS_TRITON and use_triton_ops:
pytest.skip("triton has to be installed for the test")
if not deepspeed.get_accelerator().is_triton_supported():
pytest.skip("triton is not supported on this system")
ds_out = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
residual = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
attn_output = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def run_softmax_ds(input, use_triton_ops=False):
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("use_triton_ops", [True])
def test_softmax(batch, sequence, channels, dtype, use_triton_ops):
if not deepspeed.HAS_TRITON and use_triton_ops:
pytest.skip("triton has to be installed for the test")
if not deepspeed.get_accelerator().is_triton_supported():
pytest.skip("triton is not supported on this system")

device = deepspeed.accelerator.get_accelerator().device_name()
input_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=device)
Expand Down

0 comments on commit 984d3c7

Please sign in to comment.