Move output tensor allocation out of benchmark function for GEMM (#2328)
Signed-off-by: Anatoly Myachev <[email protected]>
Co-authored-by: Whitney Tsang <[email protected]>
anmyachev and whitneywhtsang authored Sep 24, 2024
1 parent 9ad431a commit 768e5bb
Showing 6 changed files with 44 additions and 44 deletions.
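Across the six GEMM benchmarks the pattern is the same: the output buffer used to be created with torch.empty inside matmul() (or _matmul._call for split-K), i.e. inside the lambda that benchmark_suit.do_bench times, so every measured iteration also paid for an allocation. After this commit the buffer is allocated once in the benchmark() harness and passed in, and only the kernel launch is timed. A minimal sketch of the idea, with hypothetical names (run_gemm, bench) standing in for the real wrapper and for benchmark_suit.do_bench, and torch.matmul standing in for the Triton kernel:

import time

import torch


def run_gemm(a, b, c):
    # Stand-in for the Triton matmul wrapper after this commit: the caller
    # supplies the output buffer c, so nothing is allocated while timing.
    torch.matmul(a, b, out=c)


def bench(fn, reps=10):
    # Crude stand-in for benchmark_suit.do_bench: mean wall time in ms.
    fn()  # warm-up
    start = time.perf_counter()
    for _ in range(reps):
        fn()
    return (time.perf_counter() - start) / reps * 1e3


M, K, N = 1024, 1024, 1024
device = 'cpu'  # the real benchmarks run with device='xpu'
a = torch.randn((M, K), device=device)
b = torch.randn((K, N), device=device)

# Before this commit: torch.empty ran inside the timed lambda on every iteration.
before_ms = bench(lambda: run_gemm(a, b, torch.empty((M, N), device=device)))

# After this commit: allocate once in the harness and pass the buffer in.
c = torch.empty((M, N), device=device)
after_ms = bench(lambda: run_gemm(a, b, c))
print(f'alloc inside timed region: {before_ms:.3f} ms, outside: {after_ms:.3f} ms')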
16 changes: 9 additions & 7 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -157,8 +157,8 @@ def matmul_kernel_with_block_pointers_batched(
 
 
 # We can now create a convenience wrapper function that only takes two input tensors,
-# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
-def matmul(a, b):
+# and (1) checks any shape constraint; (2) launches the above kernel.
+def matmul(a, b, c):
     # Check constraints.
     if len(a.shape) == 3 and len(b.shape) == 3:
         assert a.shape[0] == b.shape[0], 'Incompatible Batch dimension'
@@ -167,8 +167,6 @@ def matmul(a, b):
         assert b.is_contiguous(), 'Matrix B must be contiguous'
         B, M, K = a.shape
         B, K, N = b.shape
-        # Allocates output.
-        c = torch.empty((B, M, N), device=a.device, dtype=torch.float32)
         # 1D launch kernel where each block gets its own program.
         grid = lambda META: (
             B,
@@ -186,8 +184,6 @@ def matmul(a, b):
         assert b.is_contiguous(), 'Matrix B must be contiguous'
         M, K = a.shape
         K, N = b.shape
-        # Allocates output.
-        c = torch.empty((M, N), device=a.device, dtype=torch.float32)
         grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
         matmul_kernel_with_block_pointers[grid](
             a, b, c, #
@@ -256,7 +252,13 @@ def benchmark(B, M, N, K, provider):
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
                                                                  quantiles=quantiles, fast_flush=False)
     elif provider == 'triton':
-        triton_fn = lambda: matmul(a, b)
+        assert len(a.shape) == len(b.shape), 'Incompatible sizes'
+        if len(a.shape) == 3:
+            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+        else:
+            assert len(a.shape) == 2, 'Expecting shape of length 2'
+            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
16 changes: 9 additions & 7 deletions
@@ -167,8 +167,8 @@ def matmul_kernel_with_block_pointers_batched(
 
 
 # We can now create a convenience wrapper function that only takes two input tensors,
-# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
-def matmul(a, b, d):
+# and (1) checks any shape constraint; (2) launches the above kernel.
+def matmul(a, b, d, c):
     # Check constraints.
     if len(a.shape) == 3 and len(b.shape) == 3:
         assert a.shape[0] == b.shape[0], 'Incompatible Batch dimension'
@@ -177,8 +177,6 @@ def matmul(a, b, d):
         assert b.is_contiguous(), 'Matrix B must be contiguous'
         B, M, K = a.shape
         B, K, N = b.shape
-        # Allocates output.
-        c = torch.empty((B, M, N), device=a.device, dtype=torch.float32)
         # 1D launch kernel where each block gets its own program.
         grid = lambda META: (
             B,
@@ -197,8 +195,6 @@ def matmul(a, b, d):
         assert b.is_contiguous(), 'Matrix B must be contiguous'
         M, K = a.shape
         K, N = b.shape
-        # Allocates output.
-        c = torch.empty((M, N), device=a.device, dtype=torch.float32)
         grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
         matmul_kernel_with_block_pointers[grid](
             a, b, c, d, #
@@ -267,7 +263,13 @@ def benchmark(B, M, N, K, provider):
     quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'triton':
-        triton_fn = lambda: matmul(a, b, d)
+        assert len(a.shape) == len(b.shape), 'Incompatible sizes'
+        if len(a.shape) == 3:
+            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+        else:
+            assert len(a.shape) == 2, 'Expecting shape of length 2'
+            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        triton_fn = lambda: matmul(a, b, d, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
16 changes: 9 additions & 7 deletions
@@ -173,8 +173,8 @@ def matmul_kernel_with_block_pointers_batched(
 
 
 # We can now create a convenience wrapper function that only takes two input tensors,
-# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
-def matmul(a, b):
+# and (1) checks any shape constraint; (2) launches the above kernel.
+def matmul(a, b, c):
     # Check constraints.
     if len(a.shape) == 3 and len(b.shape) == 3:
         assert a.shape[0] == b.shape[0], 'Incompatible Batch dimension'
@@ -183,8 +183,6 @@ def matmul(a, b):
         assert b.is_contiguous(), 'Matrix B must be contiguous'
         B, M, K = a.shape
         B, K, N = b.shape
-        # Allocates output.
-        c = torch.empty((B, M, N), device=a.device, dtype=torch.float32)
         # 1D launch kernel where each block gets its own program.
         grid = lambda META: (
             B,
@@ -202,8 +200,6 @@ def matmul(a, b):
         assert b.is_contiguous(), 'Matrix B must be contiguous'
         M, K = a.shape
         K, N = b.shape
-        # Allocates output.
-        c = torch.empty((M, N), device=a.device, dtype=torch.float32)
         grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
         matmul_kernel_with_block_pointers[grid](
             a, b, c, #
@@ -269,7 +265,13 @@ def benchmark(B, M, N, K, provider):
     quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'triton':
-        triton_fn = lambda: matmul(a, b)
+        assert len(a.shape) == len(b.shape), 'Incompatible sizes'
+        if len(a.shape) == 3:
+            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+        else:
+            assert len(a.shape) == 2, 'Expecting shape of length 2'
+            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
16 changes: 9 additions & 7 deletions benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
@@ -161,8 +161,8 @@ def matmul_kernel_with_block_pointers_batched(
 
 
 # We can now create a convenience wrapper function that only takes two input tensors,
-# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
-def matmul(a, b):
+# and (1) checks any shape constraint; (2) launches the above kernel.
+def matmul(a, b, c):
     # Check constraints.
     if len(a.shape) == 3 and len(b.shape) == 3:
         assert a.shape[0] == b.shape[0], 'Incompatible Batch dimension'
@@ -171,8 +171,6 @@ def matmul(a, b):
         assert b.is_contiguous(), 'Matrix B must be contiguous'
         B, M, K = a.shape
         B, K, N = b.shape
-        # Allocates output.
-        c = torch.empty((B, M, N), device=a.device, dtype=torch.float32)
         # 1D launch kernel where each block gets its own program.
         grid = lambda META: (
             B,
@@ -190,8 +188,6 @@ def matmul(a, b):
         assert b.is_contiguous(), 'Matrix B must be contiguous'
         M, K = a.shape
         K, N = b.shape
-        # Allocates output.
-        c = torch.empty((M, N), device=a.device, dtype=torch.float32)
         grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
         matmul_kernel_with_block_pointers[grid](
             a, b, c, #
@@ -257,7 +253,13 @@ def benchmark(B, M, N, K, provider):
     quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'triton':
-        triton_fn = lambda: matmul(a, b)
+        assert len(a.shape) == len(b.shape), 'Incompatible sizes'
+        if len(a.shape) == 3:
+            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+        else:
+            assert len(a.shape) == 2, 'Expecting shape of length 2'
+            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
17 changes: 5 additions & 12 deletions benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -70,8 +70,7 @@ class _matmul(torch.autograd.Function):
     kernel = _kernel
 
     @staticmethod
-    def _call(a, b, acc_dtype, output_dtype):
-        device = a.device
+    def _call(a, b, c, acc_dtype):
         # handle non-contiguous inputs if necessary
         if a.stride(0) > 1 and a.stride(1) > 1:
             a = a.contiguous()
@@ -82,12 +81,6 @@ def _call(a, b, acc_dtype, output_dtype):
         M, K = a.shape
         _, N = b.shape
 
-        # allocates output
-        if output_dtype is None:
-            output_dtype = torch.float32
-
-        c = torch.empty((M, N), device=device, dtype=output_dtype)
-
         # Allowed types for acc_type given the types of a and b.
         supported_acc_dtypes = {
             torch.float16: (torch.float32, torch.float16), torch.bfloat16: (torch.float32, torch.bfloat16),
@@ -105,7 +98,6 @@ def to_tl_type(ty):
             return getattr(tl, str(ty).rsplit('.', maxsplit=1)[-1])
 
         acc_dtype = to_tl_type(acc_dtype)
-        output_dtype = to_tl_type(output_dtype)
 
         # launch kernel
         grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
@@ -119,8 +111,8 @@ def to_tl_type(ty):
 
     # pylint: disable=unused-argument
     @staticmethod
-    def forward(ctx, a, b, acc_dtype=None, output_dtype=None):
-        return _matmul._call(a, b, acc_dtype=acc_dtype, output_dtype=output_dtype)
+    def forward(ctx, a, b, c, acc_dtype=None):
+        return _matmul._call(a, b, c, acc_dtype=acc_dtype)
 
 
 matmul = _matmul.apply
@@ -160,7 +152,8 @@ def benchmark(M, N, K, provider):
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
                                                               quantiles=quantiles, fast_flush=False)
     elif provider == 'triton':
-        triton_fn = lambda: matmul(a, b)
+        c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
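In the split-K benchmark above the change also drops the output_dtype plumbing: the caller now fixes the output dtype through the buffer it allocates and passes to matmul (which is _matmul.apply). A hedged usage sketch; the import path, device and input dtype below are assumptions for illustration and are not shown in the diff:

import torch

# Assumed import path (the diff shows only the file
# benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py):
from triton_kernels_benchmark.gemm_splitk_benchmark import matmul  # matmul = _matmul.apply

M, K, N = 512, 512, 512
a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
b = torch.rand((K, N), device='xpu', dtype=torch.bfloat16)

# The caller now owns the output buffer and its dtype; the benchmark uses float32.
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
matmul(a, b, c)  # acc_dtype is the remaining optional argument and defaults to None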
7 changes: 3 additions & 4 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -194,7 +194,7 @@ def full_tiles(
 # ---------------------------------------------------------------------------
 
 
-def matmul(a: torch.Tensor, b: torch.Tensor):
+def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
     num_xe_core = torch.xpu.get_device_capability(0)['gpu_subslice_count']
     streamk_programs = num_xe_core
 
@@ -226,8 +226,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor):
     streamk_full_tiles = streamk_iters // streamk_programs
     streamk_partial_tiles = streamk_iters % streamk_programs
 
-    # Allocates output.
-    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
     first_wave[(streamk_programs, )](
         a, b, c, #
         M, N, K, #
@@ -276,7 +274,8 @@ def benchmark(M, N, K, provider):
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
                                                               quantiles=quantiles, fast_flush=False)
     elif provider == 'triton':
-        triton_fn = lambda: matmul(a, b)
+        c = torch.empty((M, N), device=a.device, dtype=torch.float32)
+        triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
