From c348c5b11a4fd3f70a53f5c9445f260472de47a8 Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Mon, 6 Jan 2025 14:35:50 -0800 Subject: [PATCH] Cleanup ops/transformer/inference tests (#6925) --- .../inference/inference_test_utils.py | 42 +++++++------------ .../transformer/inference/test_attention.py | 4 +- .../transformer/inference/test_layer_norm.py | 4 +- 3 files changed, 18 insertions(+), 32 deletions(-) diff --git a/tests/unit/ops/transformer/inference/inference_test_utils.py b/tests/unit/ops/transformer/inference/inference_test_utils.py index 9cfcae809f09..d63c51267e51 100644 --- a/tests/unit/ops/transformer/inference/inference_test_utils.py +++ b/tests/unit/ops/transformer/inference/inference_test_utils.py @@ -3,6 +3,8 @@ # DeepSpeed Team +from typing import Tuple + import torch from deepspeed.accelerator import get_accelerator @@ -23,38 +25,22 @@ def get_tolerances(): DTYPES = None -def get_dtypes(): +def get_dtypes(include_float=True): global DTYPES if DTYPES is None: - DTYPES = get_accelerator().supported_dtypes() + DTYPES = [torch.float16, torch.float32] if include_float else [torch.float16] + try: + if get_accelerator().is_bf16_supported(): + DTYPES.append(torch.bfloat16) + except (AssertionError, AttributeError): + pass return DTYPES -def allclose(x, y): +def allclose(x, y, tolerances: Tuple[int, int] = None): assert x.dtype == y.dtype - rtol, atol = get_tolerances()[x.dtype] + if tolerances is None: + rtol, atol = get_tolerances()[x.dtype] + else: + rtol, atol = tolerances return torch.allclose(x, y, rtol=rtol, atol=atol) - - -def assert_almost_equal(x, y, decimal=2, err_msg=''): - import numpy.testing as npt - if isinstance(x, torch.Tensor): - if x.dtype == torch.bfloat16: - x = x.float() - x = x.cpu().detach().numpy() - if isinstance(y, torch.Tensor): - if y.dtype == torch.bfloat16: - y = y.float() - y = y.cpu().detach().numpy() - npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal) - - -def max_diff(a, b): - a = a.to(torch.float32).flatten() - b = b.to(torch.float32).flatten() - diff = torch.abs(a - b) - max_diff_indices = torch.argsort(diff)[-1] - print("Max difference indices:", max_diff_indices) - print("Max difference values:", diff[max_diff_indices]) - print(f"{a[max_diff_indices]} vs {b[max_diff_indices]}") - return max_diff_indices diff --git a/tests/unit/ops/transformer/inference/test_attention.py b/tests/unit/ops/transformer/inference/test_attention.py index ecf681542ff6..cae201d747a3 100644 --- a/tests/unit/ops/transformer/inference/test_attention.py +++ b/tests/unit/ops/transformer/inference/test_attention.py @@ -7,7 +7,7 @@ import torch import deepspeed from deepspeed.accelerator import get_accelerator -from .inference_test_utils import assert_almost_equal +from .inference_test_utils import allclose # reference timplementation @@ -88,4 +88,4 @@ def test_attention(BATCH, H, N_CTX, D_HEAD, causal, use_flash, dtype=torch.float use_triton_flash=False, use_ds_attention=False) tri_out = tri_out.reshape((BATCH, N_CTX, H, D_HEAD)).permute(0, 2, 1, 3) - assert_almost_equal(ref_out, tri_out) + assert (allclose(ref_out, tri_out)) diff --git a/tests/unit/ops/transformer/inference/test_layer_norm.py b/tests/unit/ops/transformer/inference/test_layer_norm.py index 7711daf0d887..4a84add16046 100644 --- a/tests/unit/ops/transformer/inference/test_layer_norm.py +++ b/tests/unit/ops/transformer/inference/test_layer_norm.py @@ -9,7 +9,7 @@ from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp -from .inference_test_utils import allclose, get_dtypes, assert_almost_equal +from .inference_test_utils import allclose, get_dtypes try: import triton # noqa: F401 # type: ignore from deepspeed.ops.transformer.inference.triton import ( @@ -188,4 +188,4 @@ def test_triton_layer_norm(M, N, dtype, residual, input_bias, eps=1e-5, device=' y_ref = torch.nn.functional.layer_norm(x + res + (x_bias if input_bias else 0), w_shape, weight, bias, eps).to(dtype) # compare - assert_almost_equal(y_tri, y_ref) + assert (allclose(y_tri, y_ref))