Skip to content

Commit

Permalink
Cleanup ops/transformer/inference tests (#6925)
Browse files Browse the repository at this point in the history
  • Loading branch information
loadams authored Jan 6, 2025
1 parent b0040b6 commit c348c5b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 32 deletions.
42 changes: 14 additions & 28 deletions tests/unit/ops/transformer/inference/inference_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

# DeepSpeed Team

from typing import Tuple

import torch
from deepspeed.accelerator import get_accelerator

Expand All @@ -23,38 +25,22 @@ def get_tolerances():
DTYPES = None


def get_dtypes():
def get_dtypes(include_float=True):
global DTYPES
if DTYPES is None:
DTYPES = get_accelerator().supported_dtypes()
DTYPES = [torch.float16, torch.float32] if include_float else [torch.float16]
try:
if get_accelerator().is_bf16_supported():
DTYPES.append(torch.bfloat16)
except (AssertionError, AttributeError):
pass
return DTYPES


def allclose(x, y):
def allclose(x, y, tolerances: Tuple[int, int] = None):
assert x.dtype == y.dtype
rtol, atol = get_tolerances()[x.dtype]
if tolerances is None:
rtol, atol = get_tolerances()[x.dtype]
else:
rtol, atol = tolerances
return torch.allclose(x, y, rtol=rtol, atol=atol)


def assert_almost_equal(x, y, decimal=2, err_msg=''):
import numpy.testing as npt
if isinstance(x, torch.Tensor):
if x.dtype == torch.bfloat16:
x = x.float()
x = x.cpu().detach().numpy()
if isinstance(y, torch.Tensor):
if y.dtype == torch.bfloat16:
y = y.float()
y = y.cpu().detach().numpy()
npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)


def max_diff(a, b):
a = a.to(torch.float32).flatten()
b = b.to(torch.float32).flatten()
diff = torch.abs(a - b)
max_diff_indices = torch.argsort(diff)[-1]
print("Max difference indices:", max_diff_indices)
print("Max difference values:", diff[max_diff_indices])
print(f"{a[max_diff_indices]} vs {b[max_diff_indices]}")
return max_diff_indices
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
import deepspeed
from deepspeed.accelerator import get_accelerator
from .inference_test_utils import assert_almost_equal
from .inference_test_utils import allclose


# reference implementation
Expand Down Expand Up @@ -88,4 +88,4 @@ def test_attention(BATCH, H, N_CTX, D_HEAD, causal, use_flash, dtype=torch.float
use_triton_flash=False,
use_ds_attention=False)
tri_out = tri_out.reshape((BATCH, N_CTX, H, D_HEAD)).permute(0, 2, 1, 3)
assert_almost_equal(ref_out, tri_out)
assert (allclose(ref_out, tri_out))
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp
from .inference_test_utils import allclose, get_dtypes, assert_almost_equal
from .inference_test_utils import allclose, get_dtypes
try:
import triton # noqa: F401 # type: ignore
from deepspeed.ops.transformer.inference.triton import (
Expand Down Expand Up @@ -188,4 +188,4 @@ def test_triton_layer_norm(M, N, dtype, residual, input_bias, eps=1e-5, device='
y_ref = torch.nn.functional.layer_norm(x + res + (x_bias if input_bias else 0), w_shape, weight, bias,
eps).to(dtype)
# compare
assert_almost_equal(y_tri, y_ref)
assert (allclose(y_tri, y_ref))

0 comments on commit c348c5b

Please sign in to comment.