Commit

fix and revert unnecessary changes
Signed-off-by: Anatoly Myachev <[email protected]>
anmyachev committed Sep 7, 2024
1 parent e6e3bbc commit ef78aa4
Showing 8 changed files with 84 additions and 64 deletions.
1 change: 1 addition & 0 deletions benchmarks/triton_kernels_benchmark/__init__.py
@@ -0,0 +1 @@
from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark # type: ignore # noqa: F401
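With this export in place, downstream benchmark modules can reference the helpers through the package itself rather than via relative imports. A minimal consumer-side sketch of that pattern follows; the tensor shape, device string, and warmup/rep values are illustrative assumptions, not part of this commit.

import torch

import triton_kernels_benchmark as benchmark_suit  # re-exports do_bench, assert_close, perf_report, Benchmark

# Hypothetical workload: time a row-wise softmax over a bfloat16 tensor on an XPU device.
x = torch.randn(4096, 1024, device="xpu", dtype=torch.bfloat16)
quantiles = [0.5, 0.0, 1.0]  # report the median, minimum, and maximum of the timing distribution
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
    lambda: torch.softmax(x, axis=-1), quantiles=quantiles, warmup=10, rep=10)
print(f"min={min_ms:.3f} ms  max={max_ms:.3f} ms  mean={mean:.3f} ms  cv={cv:.3f}")

The benchmark files in the diff below switch to this same benchmark_suit alias.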
18 changes: 11 additions & 7 deletions benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
@@ -3,9 +3,11 @@

import triton
import triton.language as tl

import triton_kernels_benchmark
from triton_kernels_benchmark import xetla_kernel # pylint: disable=no-name-in-module

from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark
benchmark_suit = triton_kernels_benchmark # triton.testing


# pylint: disable=unused-argument
@@ -182,8 +184,8 @@ def forward(q, k, v, causal, sm_scale):
return o


@perf_report(
Benchmark(
@benchmark_suit.perf_report(
benchmark_suit.Benchmark(
# argument names to use as an x-axis for the plot
x_names=['Z', 'H', 'N_CTX', 'D_HEAD'],
x_vals=[ #
@@ -217,7 +219,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
sm_scale = 0.125
quantiles = [0.5, 0.0, 1.0]
if provider == 'onednn':
_, min_ms, max_ms, mean, cv = do_bench(
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=
False, scale=sm_scale), warmup=10, rep=10,
quantiles=quantiles, fast_flush=False)
@@ -227,13 +229,15 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32)
atol = 1e-1 if N_CTX == 16384 else 1e-2
assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False)
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)

elif provider == 'xetla':
func = getattr(xetla_kernel, 'flash_attn')
xetla_fn = lambda: func(Z, H, D_HEAD, N_CTX, N_CTX)
_, min_ms, max_ms, mean, cv = do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)

else:
raise NotImplementedError(f'Unsupported provider {provider}')
60 changes: 32 additions & 28 deletions benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -13,9 +13,11 @@
import triton
import triton.language as tl
from triton.runtime import driver

import triton_kernels_benchmark
from triton_kernels_benchmark import xetla_kernel # pylint: disable=no-name-in-module

from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark
benchmark_suit = triton_kernels_benchmark # triton.testing


@torch.jit.script
@@ -102,49 +104,51 @@ def softmax(x):
return y


@perf_report(
Benchmark(x_names=["N"], # argument names to use as an x-axis for the plot
x_vals=[256, 1024, 2048, 4096, 1024 * 8, 1024 * 16, 1024 * 32], # different possible values for `x_name`
line_arg="provider", # argument name whose value corresponds to a different line in the plot
line_vals=[
"triton",
# "torch-native",
# "torch-jit",
"xetla",
], # possible values for `line_arg``
line_names=[
"Triton",
# "Torch (native)",
# "Torch (jit)",
"XeTLA",
], # label name for the lines
styles=[("blue", "-"), ("green", "-"), ("green", "--"), ("black", ":")], # line styles
ylabel=["GB/s", "TFlops"], # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={"M": 4096}, # values for function arguments not in `x_names` and `y_name`
))
@benchmark_suit.perf_report(
benchmark_suit.Benchmark(
x_names=["N"], # argument names to use as an x-axis for the plot
x_vals=[256, 1024, 2048, 4096, 1024 * 8, 1024 * 16, 1024 * 32], # different possible values for `x_name`
line_arg="provider", # argument name whose value corresponds to a different line in the plot
line_vals=[
"triton",
# "torch-native",
# "torch-jit",
"xetla",
], # possible values for `line_arg``
line_names=[
"Triton",
# "Torch (native)",
# "Torch (jit)",
"XeTLA",
], # label name for the lines
styles=[("blue", "-"), ("green", "-"), ("green", "--"), ("black", ":")], # line styles
ylabel=["GB/s", "TFlops"], # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={"M": 4096}, # values for function arguments not in `x_names` and `y_name`
))
def benchmark(M, N, provider):
x = torch.randn(M, N, device="xpu", dtype=torch.bfloat16)
quantiles = [0.5, 0.0, 1.0]
if provider == "torch-native":
_, min_ms, max_ms, mean, cv = do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, warmup=10,
rep=10)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles,
warmup=10, rep=10)
if provider == "triton":
triton_fn = lambda: softmax(x)
torch_fn = lambda: torch.softmax(x, axis=-1)
assert_close(triton_fn(), torch_fn(), err_msg="triton to torch")
_, min_ms, max_ms, mean, cv = do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10)
benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch")
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10)

elif provider == "torch-jit":
_, min_ms, max_ms, mean, cv = do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10, rep=10)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10,
rep=10)

elif provider == "xetla":
name = f"softmax_shape_{M}_{N}"
func = getattr(xetla_kernel, name)
xetla_fn = lambda: func(x, 0)
torch_fn = lambda: torch.softmax(x, axis=-1)
# benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch")
_, min_ms, max_ms, mean, cv = do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10)

else:
raise NotImplementedError(f"Unsupported provider {provider}")
20 changes: 11 additions & 9 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -12,9 +12,9 @@

import triton
import triton.language as tl
from triton_kernels_benchmark import xetla_kernel # pylint: disable=no-name-in-module

from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark
import triton_kernels_benchmark as benchmark_suit
from triton_kernels_benchmark import xetla_kernel # pylint: disable=no-name-in-module


@triton.autotune(
@@ -199,8 +199,8 @@ def matmul(a, b):


# Benchmark Performance
@perf_report(
Benchmark(
@benchmark_suit.perf_report(
benchmark_suit.Benchmark(
# argument names to use as an x-axis for the plot
x_names=['B', 'M', 'K', 'N'],
# different possible values for `x_name`
@@ -249,14 +249,15 @@ def benchmark(B, M, N, K, provider):
quantiles = [0.5, 0.0, 1.0]

if provider == 'onednn':
_, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
quantiles=quantiles, fast_flush=False)
elif provider == 'triton':
triton_fn = lambda: matmul(a, b)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False)
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
elif provider == 'xetla':
if B == 1:
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
@@ -271,7 +272,8 @@ def benchmark(B, M, N, K, provider):
xetla_fn = lambda: func(a, b, c, acc, cnt)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
_, min_ms, max_ms, mean_ms, cv = do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False)
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

11 changes: 6 additions & 5 deletions benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
@@ -13,7 +13,7 @@
import triton
import triton.language as tl

from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark
import triton_kernels_benchmark as benchmark_suit


@triton.autotune(
@@ -204,8 +204,8 @@ def matmul(a, b):


# Benchmark Performance
@perf_report(
Benchmark(
@benchmark_suit.perf_report(
benchmark_suit.Benchmark(
# argument names to use as an x-axis for the plot
x_names=['B', 'M', 'K', 'N'],
# different possible values for `x_name`
@@ -259,8 +259,9 @@ def benchmark(B, M, N, K, provider):
triton_fn = lambda: matmul(a, b)
torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False)
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

17 changes: 10 additions & 7 deletions benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -4,7 +4,9 @@
import triton
import triton.language as tl

from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark
import triton_kernels_benchmark

benchmark_suit = triton_kernels_benchmark # triton.testing


@triton.autotune(
@@ -126,8 +128,8 @@ def forward(ctx, a, b, acc_dtype=None, output_dtype=None):


# Benchmark Performance
@perf_report(
Benchmark(
@benchmark_suit.perf_report(
benchmark_suit.Benchmark(
# argument names to use as an x-axis for the plot
x_names=['M', 'K', 'N'],
x_vals=[
@@ -156,14 +158,15 @@ def benchmark(M, N, K, provider):
quantiles = [0.5, 0.0, 1.0]

if provider == 'onednn':
_, min_ms, max_ms, mean, cv = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
quantiles=quantiles, fast_flush=False)
elif provider == 'triton':
triton_fn = lambda: matmul(a, b)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False)
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

17 changes: 10 additions & 7 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -11,7 +11,9 @@
import triton
import triton.language as tl

from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark
import triton_kernels_benchmark

benchmark_suit = triton_kernels_benchmark # triton.testing


# pylint: disable=unused-argument
@@ -246,8 +248,8 @@ def matmul(a: torch.Tensor, b: torch.Tensor):


# Benchmark Performance
@perf_report(
Benchmark(
@benchmark_suit.perf_report(
benchmark_suit.Benchmark(
# argument names to use as an x-axis for the plot
x_names=['M', 'K', 'N'],
x_vals=[[3072, 4096, 3072]],
@@ -272,13 +274,14 @@ def benchmark(M, N, K, provider):
quantiles = [0.5, 0.0, 1.0]

if provider == 'onednn':
_, min_ms, max_ms, mean, cv = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
quantiles=quantiles, fast_flush=False)
elif provider == 'triton':
triton_fn = lambda: matmul(a, b)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False)
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

4 changes: 3 additions & 1 deletion benchmarks/xetla_kernel/python_main.cpp
@@ -5,14 +5,16 @@
#include <CL/sycl.hpp>
#include <c10/core/ScalarType.h>
#include <cstdint>
#include <torch/extension.h>

#ifdef USE_IPEX
// `#include <ipex.h>` should be before `#include <torch/extension.h>`
#include <ipex.h>
#else
#include <c10/xpu/XPUStream.h>
#endif

#include <torch/extension.h>

sycl::queue get_current_sycl_queue() {
// submit kernel
c10::impl::VirtualGuardImpl impl(at::DeviceType::XPU);
