diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
index bfc00eb01a..ecb6d8ca71 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -157,8 +157,8 @@ def matmul_kernel_with_block_pointers_batched(
 
 
 # We can now create a convenience wrapper function that only takes two input tensors,
-# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
-def matmul(a, b):
+# and (1) checks any shape constraint; (2) launches the above kernel.
+def matmul(a, b, c):
     # Check constraints.
     if len(a.shape) == 3 and len(b.shape) == 3:
         assert a.shape[0] == b.shape[0], 'Incompatible Batch dimension'
@@ -167,8 +167,6 @@ def matmul(a, b):
         assert b.is_contiguous(), 'Matrix B must be contiguous'
         B, M, K = a.shape
         B, K, N = b.shape
-        # Allocates output.
-        c = torch.empty((B, M, N), device=a.device, dtype=torch.float32)
         # 1D launch kernel where each block gets its own program.
         grid = lambda META: (
             B,
@@ -186,8 +184,6 @@ def matmul(a, b):
         assert b.is_contiguous(), 'Matrix B must be contiguous'
         M, K = a.shape
         K, N = b.shape
-        # Allocates output.
-        c = torch.empty((M, N), device=a.device, dtype=torch.float32)
         grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
         matmul_kernel_with_block_pointers[grid](
             a, b, c,  #
@@ -256,7 +252,13 @@ def benchmark(B, M, N, K, provider):
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
                                                                  quantiles=quantiles, fast_flush=False)
     elif provider == 'triton':
-        triton_fn = lambda: matmul(a, b)
+        assert len(a.shape) == len(b.shape), 'Incompatible sizes'
+        if len(a.shape) == 3:
+            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+        else:
+            assert len(a.shape) == 2, 'Expecting shape of length 2'
+            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
index 6add18b328..fa3f9d5ae2 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
@@ -167,8 +167,8 @@ def matmul_kernel_with_block_pointers_batched(
 
 
 # We can now create a convenience wrapper function that only takes two input tensors,
-# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
-def matmul(a, b, d):
+# and (1) checks any shape constraint; (2) launches the above kernel.
+def matmul(a, b, d, c):
     # Check constraints.
     if len(a.shape) == 3 and len(b.shape) == 3:
         assert a.shape[0] == b.shape[0], 'Incompatible Batch dimension'
@@ -177,8 +177,6 @@ def matmul(a, b, d):
         assert b.is_contiguous(), 'Matrix B must be contiguous'
         B, M, K = a.shape
         B, K, N = b.shape
-        # Allocates output.
-        c = torch.empty((B, M, N), device=a.device, dtype=torch.float32)
         # 1D launch kernel where each block gets its own program.
         grid = lambda META: (
             B,
@@ -197,8 +195,6 @@ def matmul(a, b, d):
         assert b.is_contiguous(), 'Matrix B must be contiguous'
         M, K = a.shape
         K, N = b.shape
-        # Allocates output.
-        c = torch.empty((M, N), device=a.device, dtype=torch.float32)
         grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
         matmul_kernel_with_block_pointers[grid](
             a, b, c, d,  #
@@ -267,7 +263,13 @@ def benchmark(B, M, N, K, provider):
         quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'triton':
-        triton_fn = lambda: matmul(a, b, d)
+        assert len(a.shape) == len(b.shape), 'Incompatible sizes'
+        if len(a.shape) == 3:
+            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+        else:
+            assert len(a.shape) == 2, 'Expecting shape of length 2'
+            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        triton_fn = lambda: matmul(a, b, d, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
index 9953cf5357..ad2efc7e6d 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
@@ -173,8 +173,8 @@ def matmul_kernel_with_block_pointers_batched(
 
 
 # We can now create a convenience wrapper function that only takes two input tensors,
-# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
-def matmul(a, b):
+# and (1) checks any shape constraint; (2) launches the above kernel.
+def matmul(a, b, c):
     # Check constraints.
     if len(a.shape) == 3 and len(b.shape) == 3:
         assert a.shape[0] == b.shape[0], 'Incompatible Batch dimension'
@@ -183,8 +183,6 @@ def matmul(a, b):
         assert b.is_contiguous(), 'Matrix B must be contiguous'
         B, M, K = a.shape
         B, K, N = b.shape
-        # Allocates output.
-        c = torch.empty((B, M, N), device=a.device, dtype=torch.float32)
         # 1D launch kernel where each block gets its own program.
         grid = lambda META: (
             B,
@@ -202,8 +200,6 @@ def matmul(a, b):
         assert b.is_contiguous(), 'Matrix B must be contiguous'
         M, K = a.shape
         K, N = b.shape
-        # Allocates output.
-        c = torch.empty((M, N), device=a.device, dtype=torch.float32)
         grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
         matmul_kernel_with_block_pointers[grid](
             a, b, c,  #
@@ -269,7 +265,13 @@ def benchmark(B, M, N, K, provider):
         quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'triton':
-        triton_fn = lambda: matmul(a, b)
+        assert len(a.shape) == len(b.shape), 'Incompatible sizes'
+        if len(a.shape) == 3:
+            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+        else:
+            assert len(a.shape) == 2, 'Expecting shape of length 2'
+            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
index 1ed4f8472e..7456f2a2f1 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
@@ -161,8 +161,8 @@ def matmul_kernel_with_block_pointers_batched(
 
 
 # We can now create a convenience wrapper function that only takes two input tensors,
-# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
-def matmul(a, b):
+# and (1) checks any shape constraint; (2) launches the above kernel.
+def matmul(a, b, c):
     # Check constraints.
     if len(a.shape) == 3 and len(b.shape) == 3:
         assert a.shape[0] == b.shape[0], 'Incompatible Batch dimension'
@@ -171,8 +171,6 @@ def matmul(a, b):
         assert b.is_contiguous(), 'Matrix B must be contiguous'
         B, M, K = a.shape
         B, K, N = b.shape
-        # Allocates output.
-        c = torch.empty((B, M, N), device=a.device, dtype=torch.float32)
         # 1D launch kernel where each block gets its own program.
         grid = lambda META: (
             B,
@@ -190,8 +188,6 @@ def matmul(a, b):
         assert b.is_contiguous(), 'Matrix B must be contiguous'
         M, K = a.shape
         K, N = b.shape
-        # Allocates output.
-        c = torch.empty((M, N), device=a.device, dtype=torch.float32)
         grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
         matmul_kernel_with_block_pointers[grid](
             a, b, c,  #
@@ -257,7 +253,13 @@ def benchmark(B, M, N, K, provider):
         quantiles = [0.5, 0.0, 1.0]
 
     if provider == 'triton':
-        triton_fn = lambda: matmul(a, b)
+        assert len(a.shape) == len(b.shape), 'Incompatible sizes'
+        if len(a.shape) == 3:
+            c = torch.empty((B, M, N), device='xpu', dtype=torch.float32)
+        else:
+            assert len(a.shape) == 2, 'Expecting shape of length 2'
+            c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
index 5fa957f69d..2146c0434e 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -70,8 +70,7 @@ class _matmul(torch.autograd.Function):
     kernel = _kernel
 
     @staticmethod
-    def _call(a, b, acc_dtype, output_dtype):
-        device = a.device
+    def _call(a, b, c, acc_dtype):
         # handle non-contiguous inputs if necessary
         if a.stride(0) > 1 and a.stride(1) > 1:
             a = a.contiguous()
@@ -82,12 +81,6 @@ def _call(a, b, acc_dtype, output_dtype):
         M, K = a.shape
         _, N = b.shape
-        # allocates output
-        if output_dtype is None:
-            output_dtype = torch.float32
-
-        c = torch.empty((M, N), device=device, dtype=output_dtype)
-
         # Allowed types for acc_type given the types of a and b.
         supported_acc_dtypes = {
             torch.float16: (torch.float32, torch.float16),
             torch.bfloat16: (torch.float32, torch.bfloat16),
@@ -105,7 +98,6 @@ def to_tl_type(ty):
             return getattr(tl, str(ty).rsplit('.', maxsplit=1)[-1])
 
         acc_dtype = to_tl_type(acc_dtype)
-        output_dtype = to_tl_type(output_dtype)
 
         # launch kernel
         grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
@@ -119,8 +111,8 @@ def to_tl_type(ty):
 
     # pylint: disable=unused-argument
     @staticmethod
-    def forward(ctx, a, b, acc_dtype=None, output_dtype=None):
-        return _matmul._call(a, b, acc_dtype=acc_dtype, output_dtype=output_dtype)
+    def forward(ctx, a, b, c, acc_dtype=None):
+        return _matmul._call(a, b, c, acc_dtype=acc_dtype)
 
 
 matmul = _matmul.apply
@@ -160,7 +152,8 @@ def benchmark(M, N, K, provider):
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
                                                               quantiles=quantiles, fast_flush=False)
     elif provider == 'triton':
-        triton_fn = lambda: matmul(a, b)
+        c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
index 7eb0b651f6..50a1f2ea32 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -194,7 +194,7 @@ def full_tiles(
 
 # ---------------------------------------------------------------------------
 
 
-def matmul(a: torch.Tensor, b: torch.Tensor):
+def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
     num_xe_core = torch.xpu.get_device_capability(0)['gpu_subslice_count']
     streamk_programs = num_xe_core
@@ -226,8 +226,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor):
     streamk_full_tiles = streamk_iters // streamk_programs
     streamk_partial_tiles = streamk_iters % streamk_programs
 
-    # Allocates output.
-    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
     first_wave[(streamk_programs, )](
         a, b, c,  #
         M, N, K,  #
@@ -276,7 +274,8 @@ def benchmark(M, N, K, provider):
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
                                                               quantiles=quantiles, fast_flush=False)
     elif provider == 'triton':
-        triton_fn = lambda: matmul(a, b)
+        c = torch.empty((M, N), device=a.device, dtype=torch.float32)
+        triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
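The pattern is the same across all six benchmarks: output allocation moves out of the `matmul` wrappers and into the benchmark setup, so `torch.empty` no longer executes inside the function that `do_bench` times. A minimal sketch of the resulting calling convention, assuming a PyTorch build with XPU support; the import path and problem sizes below are illustrative, not taken from the diff:

    import torch

    # Hypothetical import path; in the repo the wrapper lives in
    # benchmarks/triton_kernels_benchmark/gemm_benchmark.py.
    from triton_kernels_benchmark.gemm_benchmark import matmul

    M, K, N = 1024, 1024, 1024  # illustrative 2D problem size
    a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
    b = torch.rand((K, N), device='xpu', dtype=torch.bfloat16)

    # The caller now owns the output buffer; the wrapper only checks shapes
    # and launches the kernel into the tensor it is handed.
    c = torch.empty((M, N), device='xpu', dtype=torch.float32)
    matmul(a, b, c)  # the result lands in c

Because the buffer is created once before `do_bench` and the lambda merely reuses it, repeated timed calls measure only the kernel launch and execution, not allocator overhead.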