Removal of cuda hardcoded string with get_device function (#5351)
In unit tests, replaced the hardcoded 'cuda' device string with a device variable set to get_accelerator().device_name().

Co-authored-by: Shaik Raza Sikander <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
3 people authored Apr 13, 2024
1 parent 2c51aba commit f69f884
Showing 6 changed files with 31 additions and 24 deletions.
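
Every file in this commit applies the same pattern: ask the active accelerator for its device string instead of assuming CUDA, then build tensors and per-rank device strings from that variable. A minimal, hypothetical sketch of the pattern follows (the tensor shape and the local_rank handling are illustrative assumptions, not taken from any single test):

import os
import torch
from deepspeed.accelerator import get_accelerator

# Resolve the device type at runtime ("cuda", "xpu", "hpu", ...) instead of hardcoding 'cuda'.
device = get_accelerator().device_name()
local_rank = int(os.getenv("LOCAL_RANK", "0"))

# Allocate a test tensor directly on the accelerator.
x = torch.randn((2, 4), device=device)

# Or target a specific rank's device, as the hybrid-engine and stable-diffusion tests do.
y = x.to(f"{device}:{local_rank}")

On a CUDA machine device_name() still resolves to "cuda", so the tests behave exactly as before there; on other accelerators they now pick up the correct device type automatically.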
7 changes: 5 additions & 2 deletions tests/unit/hybrid_engine/test_he_lora.py
@@ -15,6 +15,7 @@
 import numpy.testing as npt
 from unit.common import DistributedTest
 from deepspeed.ops.op_builder import InferenceBuilder
+from deepspeed.accelerator import get_accelerator
 
 if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
     pytest.skip("This op had not been implemented on this system.", allow_module_level=True)
@@ -125,7 +126,8 @@ def get_model(self, model_name):
         model_config.dropout = 0.0
         model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config)
         model = model.half()
-        model = model.to(f'cuda:{local_rank}')
+        device = get_accelerator().device_name()
+        model = model.to(f'{device}:{local_rank}')
         return model
 
     def get_tokenizer(self, model_name):
@@ -190,7 +192,8 @@ def test_lora(self, batch_size, model_name, zero_stage, offload_device):
 
         model.train()
         batch = tokenizer(train_sentences, max_length=16, padding="max_length", truncation=True, return_tensors="pt")
-        batch = to_device(batch, f'cuda:{local_rank}')
+        device = get_accelerator().device_name()
+        batch = to_device(batch, f'{device}:{local_rank}')
         batch["labels"] = batch["input_ids"]
         outputs = model(**batch, use_cache=False)
         loss = outputs.loss
6 changes: 3 additions & 3 deletions tests/unit/inference/test_stable_diffusion.py
@@ -20,14 +20,14 @@ class TestStableDiffusion(DistributedTest):
     def test(self):
         from diffusers import DiffusionPipeline
         from image_similarity_measures.quality_metrics import rmse
-        generator = torch.Generator(device=get_accelerator().current_device())
+        dev = get_accelerator().device_name()
+        generator = torch.Generator(device=dev)
         seed = 0xABEDABE7
         generator.manual_seed(seed)
         prompt = "a dog on a rocket"
         model = "prompthero/midjourney-v4-diffusion"
         local_rank = int(os.getenv("LOCAL_RANK", "0"))
-        device = torch.device(f"cuda:{local_rank}")
-
+        device = torch.device(f"{dev}:{local_rank}")
         pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.half)
         pipe = pipe.to(device)
         baseline_image = pipe(prompt, guidance_scale=7.5, generator=generator).images[0]
20 changes: 10 additions & 10 deletions tests/unit/ops/transformer/inference/test_attention.py
@@ -31,23 +31,23 @@ def test_attention(BATCH, H, N_CTX, D_HEAD, causal, use_flash, dtype=torch.float
         pytest.skip("triton has to be installed for the test")
 
     minus_inf = -65504.0
-
+    dev = deepspeed.accelerator.get_accelerator().device_name()
     # skip autotune in testing
     from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul
     fp16_matmul.skip_autotune()
 
     from deepspeed.ops.transformer.inference.triton.attention import _triton_attention, _triton_packed_flash
     torch.manual_seed(20)
-    q = torch.empty((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
-    k = torch.empty((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
-    v = torch.empty((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
+    q = torch.empty((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=dev).normal_(mean=0, std=.5)
+    k = torch.empty((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=dev).normal_(mean=0, std=.5)
+    v = torch.empty((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device=dev).normal_(mean=0, std=.5)
     sm_scale = 0.3
 
     # reference implementation
     p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
     score = p
-    mask = torch.zeros((BATCH, H, N_CTX, N_CTX), dtype=dtype, device="cuda")
-    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
+    mask = torch.zeros((BATCH, H, N_CTX, N_CTX), dtype=dtype, device=dev)
+    M = torch.tril(torch.ones((N_CTX, N_CTX), device=dev))
     if causal:
         for z in range(BATCH):
             for h in range(H):
@@ -58,20 +58,20 @@ def test_attention(BATCH, H, N_CTX, D_HEAD, causal, use_flash, dtype=torch.float
     context = ref_out
 
     # adjust it to expected tensor format and run test
-    qkv = torch.randn((BATCH, N_CTX, 3 * H * D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
+    qkv = torch.randn((BATCH, N_CTX, 3 * H * D_HEAD), dtype=dtype, device=dev, requires_grad=False)
     qkv[:, :, :H * D_HEAD] = q.permute(0, 2, 1, 3).contiguous().reshape((BATCH, N_CTX, H * D_HEAD))
     qkv[:, :, 1 * H * D_HEAD:2 * H * D_HEAD] = k.permute(0, 2, 1, 3).contiguous().reshape((BATCH, N_CTX, H * D_HEAD))
     qkv[:, :, 2 * H * D_HEAD:] = v.permute(0, 2, 1, 3).contiguous().reshape((BATCH, N_CTX, H * D_HEAD))
 
     if use_flash:
         if not get_accelerator().is_triton_supported():
             pytest.skip("triton flash attention is supported when the compute capability > 8.0")
-        triton_mask = torch.zeros((BATCH, 1, 1, N_CTX), dtype=dtype, device="cuda")
+        triton_mask = torch.zeros((BATCH, 1, 1, N_CTX), dtype=dtype, device=dev)
         if not causal:
-            lengths = torch.randint(N_CTX - 8, N_CTX, (BATCH, 1), device='cuda')
+            lengths = torch.randint(N_CTX - 8, N_CTX, (BATCH, 1), device=dev)
             for i, l in enumerate(lengths):
                 triton_mask[i, ..., l:] = minus_inf
-            mask = torch.zeros((BATCH, H, N_CTX, N_CTX), dtype=dtype, device="cuda")
+            mask = torch.zeros((BATCH, H, N_CTX, N_CTX), dtype=dtype, device=dev)
             for b in range(BATCH):
                 mask[b, :, :, lengths[b]:] = minus_inf
             ref_out = ref_torch_attention(q, k, v, mask, sm_scale)
6 changes: 4 additions & 2 deletions tests/unit/ops/transformer/inference/test_gelu.py
@@ -42,8 +42,9 @@ def run_gelu_ds(activations, use_triton_ops=False):
         from deepspeed.ops.transformer.inference.triton import gelu
         return gelu(activations)
 
+    device = deepspeed.accelerator.get_accelerator().device_name()
     channels = activations.shape[-1]
-    bias = torch.zeros((channels), dtype=activations.dtype, device='cuda')
+    bias = torch.zeros((channels), dtype=activations.dtype, device=device)
     global inference_module
     if inference_module is None:
         inference_module = InferenceBuilder().load()
@@ -60,7 +61,8 @@ def run_gelu_ds(activations, use_triton_ops=False):
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("use_triton_ops", [True, False])
 def test_gelu(batch, sequence, channels, dtype, use_triton_ops):
-    activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device='cuda')
+    device = deepspeed.accelerator.get_accelerator().device_name()
+    activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=device)
     activations_ref = activations_ds.clone().detach()
 
     if not deepspeed.HAS_TRITON and use_triton_ops:
13 changes: 7 additions & 6 deletions tests/unit/ops/transformer/inference/test_layer_norm.py
@@ -175,19 +175,20 @@ def test_layer_norm_residual_store_pre_ln_res(batch, seq_len, channels, dtype):
 def test_triton_layer_norm(M, N, dtype, residual, input_bias, eps=1e-5, device='cuda'):
     if not deepspeed.HAS_TRITON:
         pytest.skip("triton has to be installed for the test")
+    dev = get_accelerator().device_name()
     torch.manual_seed(0)
     # create data
     x_shape = (M, N)
     w_shape = (x_shape[-1], )
-    weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=False)
-    bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=False)
-    x_bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=False)
-    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
+    weight = torch.rand(w_shape, dtype=dtype, device=dev, requires_grad=False)
+    bias = torch.rand(w_shape, dtype=dtype, device=dev, requires_grad=False)
+    x_bias = torch.rand(w_shape, dtype=dtype, device=dev, requires_grad=False)
+    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=dev)
     dy = .1 * torch.randn_like(x)
     if residual:
-        res = torch.rand(x_shape, dtype=dtype, device='cuda', requires_grad=False)
+        res = torch.rand(x_shape, dtype=dtype, device=dev, requires_grad=False)
     else:
-        res = torch.zeros(x_shape, dtype=dtype, device='cuda', requires_grad=False)
+        res = torch.zeros(x_shape, dtype=dtype, device=dev, requires_grad=False)
     x.requires_grad_(True)
     # forward pass
     if residual or input_bias:
3 changes: 2 additions & 1 deletion tests/unit/ops/transformer/inference/test_softmax.py
@@ -43,7 +43,8 @@ def run_softmax_ds(input, use_triton_ops=False):
 def test_softmax(batch, sequence, channels, dtype, use_triton_ops):
     if not deepspeed.HAS_TRITON and use_triton_ops:
         pytest.skip("triton has to be installed for the test")
-    input_ds = torch.randn((batch, sequence, channels), dtype=dtype, device='cuda')
+    device = deepspeed.accelerator.get_accelerator().device_name()
+    input_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=device)
     input_ref = input_ds.clone().detach()
 
     ds_out = run_softmax_ds(input_ds, use_triton_ops)
