From 1d168117a0832f625dd2246554a3d8b6e096f705 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Fri, 6 Sep 2024 19:35:22 +0000 Subject: [PATCH 01/13] Remove usage of 'torch.autograd.profiler_legacy' for benchmarks Signed-off-by: Anatoly Myachev --- benchmarks/CMakeLists.txt | 1 - benchmarks/setup.py | 3 +- .../triton_kernels_benchmark/__init__.py | 6 - .../benchmark_driver.py | 434 ------------------ .../benchmark_testing.py | 76 +-- .../flash_attention_fwd_benchmark.py | 18 +- .../triton_kernels_benchmark/fused_softmax.py | 60 ++- .../gemm_benchmark.py | 20 +- .../gemm_preop_exp_benchmark.py | 11 +- .../gemm_splitk_benchmark.py | 17 +- .../gemm_streamk_benchmark.py | 17 +- benchmarks/xetla_kernel/CMakeLists.txt | 2 - benchmarks/xetla_kernel/python_main.cpp | 11 +- 13 files changed, 74 insertions(+), 602 deletions(-) delete mode 100644 benchmarks/triton_kernels_benchmark/benchmark_driver.py diff --git a/benchmarks/CMakeLists.txt b/benchmarks/CMakeLists.txt index 03cb418132..473ca8cd04 100644 --- a/benchmarks/CMakeLists.txt +++ b/benchmarks/CMakeLists.txt @@ -11,7 +11,6 @@ endif() find_package(Python3 COMPONENTS Interpreter) find_package(Torch REQUIRED) find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib") -find_package(IPEX REQUIRED) # add the XeTLA kernel. diff --git a/benchmarks/setup.py b/benchmarks/setup.py index 08f76e21a7..46b30c37ab 100644 --- a/benchmarks/setup.py +++ b/benchmarks/setup.py @@ -8,7 +8,6 @@ from setuptools import setup import torch -import intel_extension_for_pytorch class CMakeBuild(): @@ -43,7 +42,7 @@ def build_extension(self): "Ninja", # Ninja is much faster than make "-DCMAKE_MAKE_PROGRAM=" + ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path - f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path};{intel_extension_for_pytorch.cmake_prefix_path}", + f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}", "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", "-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY=" + self.extdir, "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + self.extdir, diff --git a/benchmarks/triton_kernels_benchmark/__init__.py b/benchmarks/triton_kernels_benchmark/__init__.py index 01712272e1..e69de29bb2 100644 --- a/benchmarks/triton_kernels_benchmark/__init__.py +++ b/benchmarks/triton_kernels_benchmark/__init__.py @@ -1,6 +0,0 @@ -from triton.runtime import driver -from . import benchmark_driver -from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark # type: ignore # noqa: F401 - -# replace the launcher with the profilier hook. 
-driver.active.launcher_cls = benchmark_driver.XPULauncher diff --git a/benchmarks/triton_kernels_benchmark/benchmark_driver.py b/benchmarks/triton_kernels_benchmark/benchmark_driver.py deleted file mode 100644 index 15bc488b14..0000000000 --- a/benchmarks/triton_kernels_benchmark/benchmark_driver.py +++ /dev/null @@ -1,434 +0,0 @@ -import os -import hashlib -import importlib.util -import tempfile -from pathlib import Path - -from triton.backends.compiler import GPUTarget -from triton.backends.driver import DriverBase -from triton.runtime.cache import get_cache_manager -from triton.runtime.build import _build, quiet - -import torch -import intel_extension_for_pytorch - -_dirname = os.getenv("ZE_PATH", default="/usr/local") - -include_dir = [ - os.path.join(_dirname, "include"), - os.path.join(torch.utils.cmake_prefix_path, "../../include"), - os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include"), - os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../include") -] - -oneapi_root = os.getenv("ONEAPI_ROOT") -if oneapi_root: - include_dir += [ - os.path.join(oneapi_root, "compiler/latest/include"), - os.path.join(oneapi_root, "compiler/latest/include/sycl") - ] - -library_dir = [ - os.path.join(_dirname, "lib"), - os.path.join(torch.utils.cmake_prefix_path, "../../lib"), - os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../lib") -] -libraries = ["ze_loader", "sycl", "torch", "intel-ext-pt-gpu"] - - -def compile_module_from_src(src, name): - key = hashlib.sha256(src.encode("utf-8")).hexdigest() - cache = get_cache_manager(key) - cache_path = cache.get_file(f"{name}.so") - if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, "main.cpp") - with open(src_path, "w", encoding="utf-8") as f: - f.write(src) - with quiet(): - so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries) - with open(so, "rb") as f: - cache_path = cache.put(f.read(), f"{name}.so", binary=True) - spec = importlib.util.spec_from_file_location(name, cache_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - return mod - - -# ------------------------ -# Utils -# ------------------------ - - -class XPUUtils: - - def __new__(cls): - if not hasattr(cls, "instance"): - cls.instance = super(XPUUtils, cls).__new__(cls) - return cls.instance - - def __init__(self): - dirname = os.path.dirname(os.path.realpath(__file__)) - mod = compile_module_from_src( - Path(os.path.join(dirname, "driver.c")).read_text(encoding="utf-8"), "spirv_utils") - self.load_binary = mod.load_binary - self.get_device_properties = mod.get_device_properties - self.context = mod.init_context(self.get_sycl_queue()) - self.device_count = mod.init_devices(self.get_sycl_queue()) - self.current_device = 0 if self.device_count[0] > 0 else -1 - - def get_current_device(self): - return self.current_device - - def get_sycl_queue(self): - return torch.xpu.current_stream().sycl_queue - - -# ------------------------ -# Launcher -# ------------------------ - - -def ty_to_cpp(ty): - if ty[0] == "*": - return "void*" - return { - "i1": "int32_t", - "i8": "int8_t", - "i16": "int16_t", - "i32": "int32_t", - "i64": "int64_t", - "u1": "uint32_t", - "u8": "uint8_t", - "u16": "uint16_t", - "u32": "uint32_t", - "u64": "uint64_t", - "fp16": "float", - "bf16": "float", - "fp32": "float", - "f32": "float", - "fp64": "double", - }[ty] - - -def make_launcher(constants, signature, ids): # pylint: disable=unused-argument - # Record the 
end of regular arguments; - # subsequent arguments are architecture-specific descriptors. - arg_decls = ", ".join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) - - def _extracted_type(ty): - if ty[0] == "*": - return "PyObject*" - return ty_to_cpp(ty) - - def format_of(ty): - return { - "PyObject*": "O", - "float": "f", - "double": "d", - "long": "l", - "int8_t": "b", - "int16_t": "h", - "int32_t": "i", - "int64_t": "l", - "uint8_t": "B", - "uint16_t": "H", - "uint32_t": "I", - "uint64_t": "K", - }[ty] - - args_format = "".join([format_of(_extracted_type(ty)) for ty in signature.values()]) - fmt = "iiiOOOOOO" + args_format - args_list = ", " + ", ".join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else "" - - # generate glue code - src = f""" - #include - #include - #include - #include - #include - #include - #include - #include - - #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION - #include - #include - #include - - static inline void gpuAssert(ze_result_t code, const char *file, int line) - {{ - if (code != ZE_RESULT_SUCCESS) - {{ - const char* prefix = "Triton Error [ZE]: "; - std::string str = std::to_string(code); - char err[1024] = {{0}}; - strcat(err, prefix); - strcat(err, str.c_str()); - PyErr_SetString(PyExc_RuntimeError, err); - }} - }} - - #define ZE_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} - - typedef struct _DevicePtrInfo {{ - void* dev_ptr; - bool valid; - }} DevicePtrInfo; - - static inline void checkDevicePointer(DevicePtrInfo *ptr_info, int idx, const sycl::queue &queue) {{ - if (!ptr_info->dev_ptr || !ptr_info->valid) {{ - return; - }} - auto context = queue.get_context(); - auto handle = sycl::get_native(context); - ze_memory_allocation_properties_t prop; - prop.stype = ZE_STRUCTURE_TYPE_MEMORY_ALLOCATION_PROPERTIES; - prop.pNext = nullptr; - ze_device_handle_t device; - auto res = zeMemGetAllocProperties((ze_context_handle_t)handle, ptr_info->dev_ptr, &prop, &device); - if (res != ZE_RESULT_SUCCESS) {{ - PyErr_Format(PyExc_ValueError, - "Cannot get memory properties for pointer argument (at %d, err=%d)", idx, res); - ptr_info->valid = false; - }} else if (prop.type != ZE_MEMORY_TYPE_DEVICE) {{ - PyErr_Format(PyExc_ValueError, - "Pointer argument (at %d) doesn't reference XPU device memory (cpu tensor?)", idx); - ptr_info->valid = false; - }} - }} - - static inline DevicePtrInfo getPointer(PyObject *obj, int idx, const sycl::queue &queue) {{ - DevicePtrInfo ptr_info; - ptr_info.dev_ptr = 0; - ptr_info.valid = true; - if (PyLong_Check(obj)) {{ - ptr_info.dev_ptr = (void*) PyLong_AsLongLong(obj); - checkDevicePointer(&ptr_info, idx, queue); - return ptr_info; - }} - if (obj == Py_None) {{ - // valid nullptr - return ptr_info; - }} - PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); - if(ptr){{ - PyObject *empty_tuple = PyTuple_New(0); - PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(ptr); - if (!PyLong_Check(ret)) {{ - PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); - ptr_info.valid = false; - return ptr_info; - }} - ptr_info.dev_ptr = (void*) PyLong_AsLongLong(ret); - if(!ptr_info.dev_ptr) {{ - return ptr_info; - }} - checkDevicePointer(&ptr_info, idx, queue); - Py_DECREF(ret); // Thanks ChatGPT! 
- return ptr_info; - }} - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); - ptr_info.valid = false; - return ptr_info; - }} -// start sycl - static void set_scalar_arg( - sycl::handler& cgh, - int index, - size_t size, - const void* value) {{ - switch (size) {{ - case sizeof(uint8_t): - cgh.set_arg(index, *static_cast(value)); - break; - case sizeof(uint16_t): - cgh.set_arg(index, *static_cast(value)); - break; - case sizeof(uint32_t): - cgh.set_arg(index, *static_cast(value)); - break; - case sizeof(uint64_t): - cgh.set_arg(index, *static_cast(value)); - break; - default: - assert(false && "wrong scalar size in sycl gen."); - }} - }} - static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr {", " + arg_decls if len(arg_decls) > 0 else ""}) {{ - - std::string kernel_name = kernel_ptr.get_info(); - RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {{}}); - void *params[] = {{ {", ".join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; - uint32_t num_params = sizeof(params)/sizeof(params[0]); - uint32_t expected_num_params = kernel_ptr.get_info(); - size_t global_range_x = gridX*threads_per_warp*num_warps; - size_t global_range_y = gridY; - size_t global_range_z = gridZ; - size_t local_range_x = num_warps*threads_per_warp; - size_t local_range_y = 1; - size_t local_range_z = 1; - sycl::range<3> global_range(global_range_z, global_range_y, global_range_x); - sycl::range<3> local_range(local_range_z, local_range_y, local_range_x); - sycl::nd_range<3> parallel_work_size(global_range, local_range); - if (shared_memory) {{ - expected_num_params -= 1; - }} - assert(num_params == expected_num_params && "number of kernel param not matched"); - // Submit the imported kernel. 
- auto cgf = [&](sycl::handler &cgh) {{ - {" ".join(f"set_scalar_arg(cgh, {idx}, sizeof({ty_to_cpp(item)}), params[{idx}]);" for idx, item in enumerate([signature[i] for i in signature if i not in constants]))} - if (shared_memory) {{ - using share_mem_t = sycl::local_accessor; - share_mem_t local_buffer = share_mem_t(shared_memory, cgh); - cgh.set_arg(num_params, local_buffer); - cgh.parallel_for(parallel_work_size, kernel_ptr); - }} else {{ - cgh.parallel_for(parallel_work_size, kernel_ptr); - }} - }}; - auto event = stream.submit(cgf); - xpu::profiler_record(kernel_name, event); - }} -// end sycl - static PyObject* launch(PyObject* self, PyObject* args) {{ - - int gridX, gridY, gridZ; - PyObject *launch_enter_hook = NULL; - PyObject *launch_exit_hook = NULL; - PyObject *kernel_metadata = NULL; - PyObject *launch_metadata = NULL; - PyObject *py_obj_stream; - PyObject *py_kernel; - - {" ".join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} - if(!PyArg_ParseTuple(args, \"{fmt}\", &gridX, &gridY, &gridZ, &py_obj_stream, &py_kernel, - &kernel_metadata, &launch_metadata, - &launch_enter_hook, &launch_exit_hook {args_list})) {{ - return NULL; - }} - - // extract kernel metadata - int num_warps = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "num_warps")); - int num_ctas = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "num_ctas")); - int shared_memory = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "shared")); - int threads_per_warp = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "threads_per_warp")); - - // extract cluster dims - PyObject *clusterDim = PyObject_GetAttrString(kernel_metadata, "cluster_dims"); - if (!PyTuple_Check(kernel_metadata)) {{ - PyErr_SetString(PyExc_TypeError, "kernel_metadata.cluster_dims must be a tuple"); - return NULL; - }} - int clusterDimX = PyLong_AsLong(PyTuple_GetItem(clusterDim, 0)); - int clusterDimY = PyLong_AsLong(PyTuple_GetItem(clusterDim, 1)); - int clusterDimZ = PyLong_AsLong(PyTuple_GetItem(clusterDim, 2)); - // extract launch metadata - if (launch_enter_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_enter_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - }} - - void * pStream = PyLong_AsVoidPtr(py_obj_stream); - //error check - if(pStream == nullptr || py_kernel == nullptr) return NULL; - - sycl::queue stream = *(static_cast(pStream)); - sycl::kernel* kernel_ptr = reinterpret_cast(PyCapsule_GetPointer(py_kernel, "kernel")); - if(kernel_ptr == nullptr) return NULL; - sycl::kernel kernel = *kernel_ptr; - - {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; - sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {"," + ", ".join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""}); - - if(launch_exit_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_exit_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - }} - if (PyErr_Occurred()) {{ - return NULL; - }} - - // return None - Py_INCREF(Py_None); - return Py_None; - }} - - static PyMethodDef ModuleMethods[] = {{ - {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, - {{NULL, NULL, 0, NULL}} // sentinel - }}; - - static struct 
PyModuleDef ModuleDef = {{ - PyModuleDef_HEAD_INIT, - \"__triton_launcher\", - NULL, //documentation - -1, //size - ModuleMethods - }}; - - PyMODINIT_FUNC PyInit___triton_launcher(void) {{ - PyObject *m = PyModule_Create(&ModuleDef); - if(m == NULL) {{ - return NULL; - }} - PyModule_AddFunctions(m, ModuleMethods); - return m; - }} - """ - return src - - -class XPULauncher: - - def __init__(self, src, metadata): # pylint: disable=unused-argument - ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} - constants = src.constants if hasattr(src, "constants") else {} - cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i - constants = {cst_key(key): value for key, value in constants.items()} - signature = {cst_key(key): value for key, value in src.signature.items()} - src = make_launcher(constants, signature, ids) - mod = compile_module_from_src(src, "__triton_launcher") - self.launch = mod.launch - - def __call__(self, *args, **kwargs): - self.launch(*args, **kwargs) - - -class XPUDriver(DriverBase): - - def __init__(self): - self.launcher_cls = XPULauncher - - def __getattr__(self, name): - # Lazily initialize utils to avoid unnecessary XPU runtime invocations. - # See https://github.com/intel/intel-xpu-backend-for-triton/issues/624 - if name == "utils": - self.utils = XPUUtils() # pylint: disable=attribute-defined-outside-init - return self.utils - raise AttributeError - - def get_current_device(self): - return self.utils.get_current_device() - - def get_current_stream(self, device): # pylint: disable=unused-argument - return torch.xpu.current_stream().sycl_queue - - def get_current_target(self): - device = self.get_current_device() - dev_property = torch.xpu.get_device_capability(device) - warp_size = 32 - return GPUTarget("xpu", dev_property, warp_size) - - @staticmethod - def is_active(): - return torch.xpu.is_available() diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py index f41bb89194..185930c16f 100644 --- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py +++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py @@ -2,18 +2,11 @@ import itertools import os from typing import Any, Dict, List - - -def synchronize(): - import torch - if torch.cuda.is_available(): - torch.cuda.synchronize() - elif torch.xpu.is_available(): - torch.xpu.synchronize() +from triton.testing import do_bench as triton_do_bench def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean", - device="xpu", sync_submitting=True): + device="xpu"): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. 
@@ -33,69 +26,10 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu """ assert return_mode in ["min", "max", "mean", "median"] import torch - from torch.autograd.profiler import record_function - - fn() - synchronize() - - # We maintain a buffer of 256 MB that we clear - # before each kernel call to make sure that the L2 - # doesn't contain any input data before the run - cache_size = 256 * 1024 * 1024 - if fast_flush: - cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device) - else: - cache = torch.empty(int(cache_size), dtype=torch.int8, device=device) - - # Estimate the runtime of the function - start_event = torch.xpu.Event(enable_timing=True) - end_event = torch.xpu.Event(enable_timing=True) - start_event.record() - for _ in range(5): - cache.zero_() - fn() - end_event.record() - synchronize() - estimate_ms = start_event.elapsed_time(end_event) / 5 - - # compute number of warmup and repeat - n_warmup = max(1, int(warmup / estimate_ms)) - n_repeat = max(1, int(rep / estimate_ms)) - # Warm-up - for _ in range(n_warmup): - fn() - # Benchmark - with torch.autograd.profiler_legacy.profile(True, use_xpu=True) as prof: - for _ in range(n_repeat): - # we don't want `fn` to accumulate gradient values - # if it contains a backward pass. So we clear the - # provided gradients - if grad_to_none is not None: - for x in grad_to_none: - x.grad = None - # we clear the L2 cache before each run - cache.zero_() - if sync_submitting: - synchronize() - # record time of `fn` - with record_function("__profile_kernel_of_func"): - fn() - # Record clocks - synchronize() - - profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), prof.function_events) - functions = list(profiling_func_filter) - - def extract_kernels(funcs): - kernels = [] - kernels += list(itertools.chain.from_iterable(map(lambda func: extract_kernels(func.cpu_children), funcs))) - kernels += list(itertools.chain.from_iterable([func.kernels for func in funcs])) - return kernels - kernels = [extract_kernels(func.cpu_children) for func in functions] - assert len(kernels) == n_repeat, "the profiling number not match" - # Make the time to the milliseconds. 
- times = torch.tensor([sum([k.duration for k in ks]) * 1e-3 for ks in kernels], dtype=torch.float) + times = triton_do_bench(fn, warmup=warmup, rep=rep, grad_to_none=grad_to_none, fast_flush=fast_flush, + return_mode="all", device_type=device) + times = torch.tensor(times, dtype=torch.float) if quantiles is not None: ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() if times.numel() > 2: diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index d409996f3f..aede62ba02 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -3,11 +3,9 @@ import triton import triton.language as tl - -import triton_kernels_benchmark from triton_kernels_benchmark import xetla_kernel # pylint: disable=no-name-in-module -benchmark_suit = triton_kernels_benchmark # triton.testing +from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark # pylint: disable=unused-argument @@ -184,8 +182,8 @@ def forward(q, k, v, causal, sm_scale): return o -@benchmark_suit.perf_report( - benchmark_suit.Benchmark( +@perf_report( + Benchmark( # argument names to use as an x-axis for the plot x_names=['Z', 'H', 'N_CTX', 'D_HEAD'], x_vals=[ # @@ -219,7 +217,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider): sm_scale = 0.125 quantiles = [0.5, 0.0, 1.0] if provider == 'onednn': - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( + _, min_ms, max_ms, mean, cv = do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= False, scale=sm_scale), warmup=10, rep=10, quantiles=quantiles, fast_flush=False) @@ -229,15 +227,13 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider): torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32) atol = 1e-1 if N_CTX == 16384 else 1e-2 - benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch') + _, min_ms, max_ms, mean, cv = do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False) elif provider == 'xetla': func = getattr(xetla_kernel, 'flash_attn') xetla_fn = lambda: func(Z, H, D_HEAD, N_CTX, N_CTX) - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + _, min_ms, max_ms, mean, cv = do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py index 4a229a6b9e..7d5ab18472 100644 --- a/benchmarks/triton_kernels_benchmark/fused_softmax.py +++ b/benchmarks/triton_kernels_benchmark/fused_softmax.py @@ -13,11 +13,9 @@ import triton import triton.language as tl from triton.runtime import driver - -import triton_kernels_benchmark from triton_kernels_benchmark import xetla_kernel # pylint: disable=no-name-in-module -benchmark_suit = triton_kernels_benchmark # triton.testing +from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark 
@torch.jit.script @@ -104,43 +102,41 @@ def softmax(x): return y -@benchmark_suit.perf_report( - benchmark_suit.Benchmark( - x_names=["N"], # argument names to use as an x-axis for the plot - x_vals=[256, 1024, 2048, 4096, 1024 * 8, 1024 * 16, 1024 * 32], # different possible values for `x_name` - line_arg="provider", # argument name whose value corresponds to a different line in the plot - line_vals=[ - "triton", - # "torch-native", - # "torch-jit", - "xetla", - ], # possible values for `line_arg`` - line_names=[ - "Triton", - # "Torch (native)", - # "Torch (jit)", - "XeTLA", - ], # label name for the lines - styles=[("blue", "-"), ("green", "-"), ("green", "--"), ("black", ":")], # line styles - ylabel=["GB/s", "TFlops"], # label name for the y-axis - plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot. - args={"M": 4096}, # values for function arguments not in `x_names` and `y_name` - )) +@perf_report( + Benchmark(x_names=["N"], # argument names to use as an x-axis for the plot + x_vals=[256, 1024, 2048, 4096, 1024 * 8, 1024 * 16, 1024 * 32], # different possible values for `x_name` + line_arg="provider", # argument name whose value corresponds to a different line in the plot + line_vals=[ + "triton", + # "torch-native", + # "torch-jit", + "xetla", + ], # possible values for `line_arg`` + line_names=[ + "Triton", + # "Torch (native)", + # "Torch (jit)", + "XeTLA", + ], # label name for the lines + styles=[("blue", "-"), ("green", "-"), ("green", "--"), ("black", ":")], # line styles + ylabel=["GB/s", "TFlops"], # label name for the y-axis + plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot. + args={"M": 4096}, # values for function arguments not in `x_names` and `y_name` + )) def benchmark(M, N, provider): x = torch.randn(M, N, device="xpu", dtype=torch.bfloat16) quantiles = [0.5, 0.0, 1.0] if provider == "torch-native": - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, - warmup=10, rep=10) + _, min_ms, max_ms, mean, cv = do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, warmup=10, + rep=10) if provider == "triton": triton_fn = lambda: softmax(x) torch_fn = lambda: torch.softmax(x, axis=-1) - benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch") - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10) + assert_close(triton_fn(), torch_fn(), err_msg="triton to torch") + _, min_ms, max_ms, mean, cv = do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10) elif provider == "torch-jit": - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10, - rep=10) + _, min_ms, max_ms, mean, cv = do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10, rep=10) elif provider == "xetla": name = f"softmax_shape_{M}_{N}" @@ -148,7 +144,7 @@ def benchmark(M, N, provider): xetla_fn = lambda: func(x, 0) torch_fn = lambda: torch.softmax(x, axis=-1) # benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch") - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10) + _, min_ms, max_ms, mean, cv = do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10) else: raise NotImplementedError(f"Unsupported provider {provider}") diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py 
b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py index b749a5c79b..da57019558 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py @@ -12,10 +12,10 @@ import triton import triton.language as tl - -import triton_kernels_benchmark as benchmark_suit from triton_kernels_benchmark import xetla_kernel # pylint: disable=no-name-in-module +from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark + @triton.autotune( configs=[ @@ -199,8 +199,8 @@ def matmul(a, b): # Benchmark Performance -@benchmark_suit.perf_report( - benchmark_suit.Benchmark( +@perf_report( + Benchmark( # argument names to use as an x-axis for the plot x_names=['B', 'M', 'K', 'N'], # different possible values for `x_name` @@ -249,15 +249,14 @@ def benchmark(B, M, N, K, provider): quantiles = [0.5, 0.0, 1.0] if provider == 'onednn': - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, - quantiles=quantiles, fast_flush=False) + _, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, quantiles=quantiles, + fast_flush=False) elif provider == 'triton': triton_fn = lambda: matmul(a, b) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 - benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') + _, min_ms, max_ms, mean_ms, cv = do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False) elif provider == 'xetla': if B == 1: c = torch.empty((M, N), device='xpu', dtype=torch.float32) @@ -272,8 +271,7 @@ def benchmark(B, M, N, K, provider): xetla_fn = lambda: func(a, b, c, acc, cnt) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + _, min_ms, max_ms, mean_ms, cv = do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py index f229f1b546..0f66eb193d 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py @@ -13,7 +13,7 @@ import triton import triton.language as tl -import triton_kernels_benchmark as benchmark_suit +from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark @triton.autotune( @@ -204,8 +204,8 @@ def matmul(a, b): # Benchmark Performance -@benchmark_suit.perf_report( - benchmark_suit.Benchmark( +@perf_report( + Benchmark( # argument names to use as an x-axis for the plot x_names=['B', 'M', 'K', 'N'], # different possible values for `x_name` @@ -259,9 +259,8 @@ def benchmark(B, M, N, K, provider): triton_fn = lambda: matmul(a, b) torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 - benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to 
torch') - _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') + _, min_ms, max_ms, mean_ms, cv = do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py index ea7e45d988..f65661c00e 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py @@ -4,9 +4,7 @@ import triton import triton.language as tl -import triton_kernels_benchmark - -benchmark_suit = triton_kernels_benchmark # triton.testing +from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark @triton.autotune( @@ -128,8 +126,8 @@ def forward(ctx, a, b, acc_dtype=None, output_dtype=None): # Benchmark Performance -@benchmark_suit.perf_report( - benchmark_suit.Benchmark( +@perf_report( + Benchmark( # argument names to use as an x-axis for the plot x_names=['M', 'K', 'N'], x_vals=[ @@ -158,15 +156,14 @@ def benchmark(M, N, K, provider): quantiles = [0.5, 0.0, 1.0] if provider == 'onednn': - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, - quantiles=quantiles, fast_flush=False) + _, min_ms, max_ms, mean, cv = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, quantiles=quantiles, + fast_flush=False) elif provider == 'triton': triton_fn = lambda: matmul(a, b) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 - benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') + _, min_ms, max_ms, mean, cv = do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py index 66019bdb25..4d34e361d4 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py @@ -11,9 +11,7 @@ import triton import triton.language as tl -import triton_kernels_benchmark - -benchmark_suit = triton_kernels_benchmark # triton.testing +from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark # pylint: disable=unused-argument @@ -248,8 +246,8 @@ def matmul(a: torch.Tensor, b: torch.Tensor): # Benchmark Performance -@benchmark_suit.perf_report( - benchmark_suit.Benchmark( +@perf_report( + Benchmark( # argument names to use as an x-axis for the plot x_names=['M', 'K', 'N'], x_vals=[[3072, 4096, 3072]], @@ -274,14 +272,13 @@ def benchmark(M, N, K, provider): quantiles = [0.5, 0.0, 1.0] if provider == 'onednn': - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, - quantiles=quantiles, fast_flush=False) + _, min_ms, max_ms, mean, cv = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, quantiles=quantiles, + fast_flush=False) elif provider == 'triton': 
triton_fn = lambda: matmul(a, b) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) - benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch') + _, min_ms, max_ms, mean, cv = do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/xetla_kernel/CMakeLists.txt b/benchmarks/xetla_kernel/CMakeLists.txt index 02e5e79475..0130e95666 100644 --- a/benchmarks/xetla_kernel/CMakeLists.txt +++ b/benchmarks/xetla_kernel/CMakeLists.txt @@ -33,9 +33,7 @@ target_compile_options(xetla_kernel PRIVATE "-fsycl") target_compile_options(xetla_kernel PUBLIC "-DXETPP_NEW_XMAIN") target_link_options(xetla_kernel PRIVATE ${XETLA_KERNEL_FLAGS}) target_link_libraries(xetla_kernel PUBLIC ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY}) -target_link_libraries(xetla_kernel PUBLIC "${TORCH_IPEX_LIBRARIES}") target_include_directories(xetla_kernel PUBLIC "${PYTHON_INCLUDE_DIRS}") -target_include_directories(xetla_kernel PUBLIC "${TORCH_IPEX_INCLUDE_DIRS}") target_include_directories(xetla_kernel PUBLIC "${XeTLALibrary_INCLUDE_DIR}") add_subdirectory(softmax) diff --git a/benchmarks/xetla_kernel/python_main.cpp b/benchmarks/xetla_kernel/python_main.cpp index 574c1a4e80..7270376bd9 100644 --- a/benchmarks/xetla_kernel/python_main.cpp +++ b/benchmarks/xetla_kernel/python_main.cpp @@ -4,8 +4,8 @@ #include "stream_k_gemm/stream_k_gemm.h" #include #include +#include #include -#include #include sycl::queue get_current_sycl_queue() { @@ -13,7 +13,10 @@ sycl::queue get_current_sycl_queue() { c10::impl::VirtualGuardImpl impl(at::DeviceType::XPU); c10::Stream stream = impl.getStream(impl.getDevice()); - return xpu::get_queue_from_stream(stream); + auto xpu_stream = c10::xpu::XPUStream(stream); + auto queue = xpu_stream.queue(); + + return queue; } #define CHECK_XPU(x) \ @@ -33,7 +36,6 @@ at::Tensor softmax(const at::Tensor &input, const int64_t dim) { auto queue = get_current_sycl_queue(); auto evt = softmax_forward(input.data_ptr(), output.data_ptr(), queue); - xpu::profiler_record("xetla kernel", evt); return output; } @@ -50,7 +52,6 @@ at::Tensor bf16_gemm(const at::Tensor &a, const at::Tensor &b, auto queue = get_current_sycl_queue(); auto evt = gemm_run(a.data_ptr(), b.data_ptr(), c.data_ptr(), acc.data_ptr(), cnt.data_ptr(), queue); - xpu::profiler_record("xetla kernel", evt); return acc; } @@ -66,7 +67,6 @@ at::Tensor bf16_stream_k_gemm(const at::Tensor &a, const at::Tensor &b, auto queue = get_current_sycl_queue(); auto evt = stream_k_gemm_run(a.data_ptr(), b.data_ptr(), c.data_ptr(), acc.data_ptr(), cnt.data_ptr(), queue); - xpu::profiler_record("xetla kernel", evt); return acc; } @@ -100,7 +100,6 @@ void flash_attn(const int64_t num_batches, const int64_t num_heads, << "\n"; } - xpu::profiler_record("xetla kernel", evt); return; } From e60f1f40944019a0140649940b6c0728df4adc13 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Fri, 6 Sep 2024 21:46:00 +0000 Subject: [PATCH 02/13] add 'USE_IPEX' compilation option Signed-off-by: Anatoly Myachev --- benchmarks/CMakeLists.txt | 7 +++++++ benchmarks/setup.py | 7 ++++++- benchmarks/xetla_kernel/CMakeLists.txt | 5 +++++ benchmarks/xetla_kernel/python_main.cpp | 11 ++++++++++- 4 files changed, 28 insertions(+), 
2 deletions(-) diff --git a/benchmarks/CMakeLists.txt b/benchmarks/CMakeLists.txt index 473ca8cd04..7d1c59ea6f 100644 --- a/benchmarks/CMakeLists.txt +++ b/benchmarks/CMakeLists.txt @@ -4,6 +4,8 @@ set(CMAKE_CXX_STANDARD 20) project(TritonBenchmark) +option(USE_IPEX "Use IPEX" ON) + if(NOT WIN32) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") endif() @@ -12,6 +14,11 @@ find_package(Python3 COMPONENTS Interpreter) find_package(Torch REQUIRED) find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib") +if(USE_IPEX) + string(APPEND CMAKE_CXX_FLAGS " -DUSE_IPEX") + find_package(IPEX REQUIRED) +endif() + # add the XeTLA kernel. add_subdirectory(xetla_kernel) diff --git a/benchmarks/setup.py b/benchmarks/setup.py index 46b30c37ab..03246a995a 100644 --- a/benchmarks/setup.py +++ b/benchmarks/setup.py @@ -9,6 +9,11 @@ import torch +ipex_cmake_prefix_path = "" +if os.getenv("USE_IPEX", "1") == "1": + import intel_extension_for_pytorch + ipex_cmake_prefix_path = f";{intel_extension_for_pytorch.cmake_prefix_path}" + class CMakeBuild(): @@ -42,7 +47,7 @@ def build_extension(self): "Ninja", # Ninja is much faster than make "-DCMAKE_MAKE_PROGRAM=" + ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path - f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}", + f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}{ipex_cmake_prefix_path}", "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", "-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY=" + self.extdir, "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + self.extdir, diff --git a/benchmarks/xetla_kernel/CMakeLists.txt b/benchmarks/xetla_kernel/CMakeLists.txt index 0130e95666..64334dc518 100644 --- a/benchmarks/xetla_kernel/CMakeLists.txt +++ b/benchmarks/xetla_kernel/CMakeLists.txt @@ -36,6 +36,11 @@ target_link_libraries(xetla_kernel PUBLIC ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBR target_include_directories(xetla_kernel PUBLIC "${PYTHON_INCLUDE_DIRS}") target_include_directories(xetla_kernel PUBLIC "${XeTLALibrary_INCLUDE_DIR}") +if(USE_IPEX) + target_link_libraries(xetla_kernel PUBLIC "${TORCH_IPEX_LIBRARIES}") + target_include_directories(xetla_kernel PUBLIC "${TORCH_IPEX_INCLUDE_DIRS}") +endif() + add_subdirectory(softmax) add_subdirectory(gemm) add_subdirectory(stream_k_gemm) diff --git a/benchmarks/xetla_kernel/python_main.cpp b/benchmarks/xetla_kernel/python_main.cpp index 7270376bd9..9720cb89b8 100644 --- a/benchmarks/xetla_kernel/python_main.cpp +++ b/benchmarks/xetla_kernel/python_main.cpp @@ -4,17 +4,26 @@ #include "stream_k_gemm/stream_k_gemm.h" #include #include -#include #include #include +#ifdef USE_IPEX +#include +#else +#include +#endif + sycl::queue get_current_sycl_queue() { // submit kernel c10::impl::VirtualGuardImpl impl(at::DeviceType::XPU); c10::Stream stream = impl.getStream(impl.getDevice()); +#ifdef USE_IPEX + auto queue = xpu::get_queue_from_stream(stream); +#else auto xpu_stream = c10::xpu::XPUStream(stream); auto queue = xpu_stream.queue(); +#endif return queue; } From b7ddbc936eea0c43518a6e216f8ace8ca46937a4 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Sat, 7 Sep 2024 12:25:08 +0000 Subject: [PATCH 03/13] fix and revert unnecessary changes Signed-off-by: Anatoly Myachev --- .../triton_kernels_benchmark/__init__.py | 1 + .../flash_attention_fwd_benchmark.py | 18 +++--- .../triton_kernels_benchmark/fused_softmax.py | 60 ++++++++++--------- .../gemm_benchmark.py | 20 ++++--- .../gemm_preop_exp_benchmark.py | 11 ++-- .../gemm_splitk_benchmark.py | 17 +++--- .../gemm_streamk_benchmark.py | 
17 +++--- benchmarks/xetla_kernel/python_main.cpp | 4 +- 8 files changed, 84 insertions(+), 64 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/__init__.py b/benchmarks/triton_kernels_benchmark/__init__.py index e69de29bb2..576e9dcc16 100644 --- a/benchmarks/triton_kernels_benchmark/__init__.py +++ b/benchmarks/triton_kernels_benchmark/__init__.py @@ -0,0 +1 @@ +from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark # type: ignore # noqa: F401 diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index aede62ba02..d409996f3f 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -3,9 +3,11 @@ import triton import triton.language as tl + +import triton_kernels_benchmark from triton_kernels_benchmark import xetla_kernel # pylint: disable=no-name-in-module -from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark +benchmark_suit = triton_kernels_benchmark # triton.testing # pylint: disable=unused-argument @@ -182,8 +184,8 @@ def forward(q, k, v, causal, sm_scale): return o -@perf_report( - Benchmark( +@benchmark_suit.perf_report( + benchmark_suit.Benchmark( # argument names to use as an x-axis for the plot x_names=['Z', 'H', 'N_CTX', 'D_HEAD'], x_vals=[ # @@ -217,7 +219,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider): sm_scale = 0.125 quantiles = [0.5, 0.0, 1.0] if provider == 'onednn': - _, min_ms, max_ms, mean, cv = do_bench( + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench( lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal= False, scale=sm_scale), warmup=10, rep=10, quantiles=quantiles, fast_flush=False) @@ -227,13 +229,15 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider): torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32) atol = 1e-1 if N_CTX == 16384 else 1e-2 - assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False) + benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch') + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, + fast_flush=False) elif provider == 'xetla': func = getattr(xetla_kernel, 'flash_attn') xetla_fn = lambda: func(Z, H, D_HEAD, N_CTX, N_CTX) - _, min_ms, max_ms, mean, cv = do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False) + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, + fast_flush=False) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py index 7d5ab18472..4a229a6b9e 100644 --- a/benchmarks/triton_kernels_benchmark/fused_softmax.py +++ b/benchmarks/triton_kernels_benchmark/fused_softmax.py @@ -13,9 +13,11 @@ import triton import triton.language as tl from triton.runtime import driver + +import triton_kernels_benchmark from triton_kernels_benchmark import xetla_kernel # pylint: disable=no-name-in-module -from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark 
+benchmark_suit = triton_kernels_benchmark # triton.testing @torch.jit.script @@ -102,41 +104,43 @@ def softmax(x): return y -@perf_report( - Benchmark(x_names=["N"], # argument names to use as an x-axis for the plot - x_vals=[256, 1024, 2048, 4096, 1024 * 8, 1024 * 16, 1024 * 32], # different possible values for `x_name` - line_arg="provider", # argument name whose value corresponds to a different line in the plot - line_vals=[ - "triton", - # "torch-native", - # "torch-jit", - "xetla", - ], # possible values for `line_arg`` - line_names=[ - "Triton", - # "Torch (native)", - # "Torch (jit)", - "XeTLA", - ], # label name for the lines - styles=[("blue", "-"), ("green", "-"), ("green", "--"), ("black", ":")], # line styles - ylabel=["GB/s", "TFlops"], # label name for the y-axis - plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot. - args={"M": 4096}, # values for function arguments not in `x_names` and `y_name` - )) +@benchmark_suit.perf_report( + benchmark_suit.Benchmark( + x_names=["N"], # argument names to use as an x-axis for the plot + x_vals=[256, 1024, 2048, 4096, 1024 * 8, 1024 * 16, 1024 * 32], # different possible values for `x_name` + line_arg="provider", # argument name whose value corresponds to a different line in the plot + line_vals=[ + "triton", + # "torch-native", + # "torch-jit", + "xetla", + ], # possible values for `line_arg`` + line_names=[ + "Triton", + # "Torch (native)", + # "Torch (jit)", + "XeTLA", + ], # label name for the lines + styles=[("blue", "-"), ("green", "-"), ("green", "--"), ("black", ":")], # line styles + ylabel=["GB/s", "TFlops"], # label name for the y-axis + plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot. + args={"M": 4096}, # values for function arguments not in `x_names` and `y_name` + )) def benchmark(M, N, provider): x = torch.randn(M, N, device="xpu", dtype=torch.bfloat16) quantiles = [0.5, 0.0, 1.0] if provider == "torch-native": - _, min_ms, max_ms, mean, cv = do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, warmup=10, - rep=10) + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, + warmup=10, rep=10) if provider == "triton": triton_fn = lambda: softmax(x) torch_fn = lambda: torch.softmax(x, axis=-1) - assert_close(triton_fn(), torch_fn(), err_msg="triton to torch") - _, min_ms, max_ms, mean, cv = do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10) + benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch") + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10) elif provider == "torch-jit": - _, min_ms, max_ms, mean, cv = do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10, rep=10) + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10, + rep=10) elif provider == "xetla": name = f"softmax_shape_{M}_{N}" @@ -144,7 +148,7 @@ def benchmark(M, N, provider): xetla_fn = lambda: func(x, 0) torch_fn = lambda: torch.softmax(x, axis=-1) # benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch") - _, min_ms, max_ms, mean, cv = do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10) + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10) else: raise NotImplementedError(f"Unsupported provider {provider}") diff --git 
a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py index da57019558..b749a5c79b 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py @@ -12,9 +12,9 @@ import triton import triton.language as tl -from triton_kernels_benchmark import xetla_kernel # pylint: disable=no-name-in-module -from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark +import triton_kernels_benchmark as benchmark_suit +from triton_kernels_benchmark import xetla_kernel # pylint: disable=no-name-in-module @triton.autotune( @@ -199,8 +199,8 @@ def matmul(a, b): # Benchmark Performance -@perf_report( - Benchmark( +@benchmark_suit.perf_report( + benchmark_suit.Benchmark( # argument names to use as an x-axis for the plot x_names=['B', 'M', 'K', 'N'], # different possible values for `x_name` @@ -249,14 +249,15 @@ def benchmark(B, M, N, K, provider): quantiles = [0.5, 0.0, 1.0] if provider == 'onednn': - _, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, + quantiles=quantiles, fast_flush=False) elif provider == 'triton': triton_fn = lambda: matmul(a, b) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 - assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean_ms, cv = do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False) + benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, + fast_flush=False) elif provider == 'xetla': if B == 1: c = torch.empty((M, N), device='xpu', dtype=torch.float32) @@ -271,7 +272,8 @@ def benchmark(B, M, N, K, provider): xetla_fn = lambda: func(a, b, c, acc, cnt) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) # benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch') - _, min_ms, max_ms, mean_ms, cv = do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False) + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, + fast_flush=False) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py index 0f66eb193d..f229f1b546 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py @@ -13,7 +13,7 @@ import triton import triton.language as tl -from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark +import triton_kernels_benchmark as benchmark_suit @triton.autotune( @@ -204,8 +204,8 @@ def matmul(a, b): # Benchmark Performance -@perf_report( - Benchmark( +@benchmark_suit.perf_report( + benchmark_suit.Benchmark( # argument names to use as an x-axis for the plot x_names=['B', 'M', 'K', 'N'], # different possible values for `x_name` @@ -259,8 +259,9 @@ def benchmark(B, M, N, K, provider): triton_fn = lambda: matmul(a, b) torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32) rtol = 1e-2 if a.dtype 
== torch.bfloat16 else 1e-3 - assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean_ms, cv = do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False) + benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') + _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, + fast_flush=False) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py index f65661c00e..ea7e45d988 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py @@ -4,7 +4,9 @@ import triton import triton.language as tl -from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark +import triton_kernels_benchmark + +benchmark_suit = triton_kernels_benchmark # triton.testing @triton.autotune( @@ -126,8 +128,8 @@ def forward(ctx, a, b, acc_dtype=None, output_dtype=None): # Benchmark Performance -@perf_report( - Benchmark( +@benchmark_suit.perf_report( + benchmark_suit.Benchmark( # argument names to use as an x-axis for the plot x_names=['M', 'K', 'N'], x_vals=[ @@ -156,14 +158,15 @@ def benchmark(M, N, K, provider): quantiles = [0.5, 0.0, 1.0] if provider == 'onednn': - _, min_ms, max_ms, mean, cv = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, + quantiles=quantiles, fast_flush=False) elif provider == 'triton': triton_fn = lambda: matmul(a, b) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3 - assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False) + benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, + fast_flush=False) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py index 4d34e361d4..66019bdb25 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py @@ -11,7 +11,9 @@ import triton import triton.language as tl -from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark +import triton_kernels_benchmark + +benchmark_suit = triton_kernels_benchmark # triton.testing # pylint: disable=unused-argument @@ -246,8 +248,8 @@ def matmul(a: torch.Tensor, b: torch.Tensor): # Benchmark Performance -@perf_report( - Benchmark( +@benchmark_suit.perf_report( + benchmark_suit.Benchmark( # argument names to use as an x-axis for the plot x_names=['M', 'K', 'N'], x_vals=[[3072, 4096, 3072]], @@ -272,13 +274,14 @@ def benchmark(M, N, K, provider): quantiles = [0.5, 0.0, 1.0] if provider == 'onednn': - _, min_ms, max_ms, mean, cv = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, quantiles=quantiles, - fast_flush=False) + _, min_ms, max_ms, mean, cv = 
benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10, + quantiles=quantiles, fast_flush=False) elif provider == 'triton': triton_fn = lambda: matmul(a, b) torch_fn = lambda: torch.matmul(a, b).to(torch.float32) - assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch') - _, min_ms, max_ms, mean, cv = do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False) + benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch') + _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles, + fast_flush=False) else: raise NotImplementedError(f'Unsupported provider {provider}') diff --git a/benchmarks/xetla_kernel/python_main.cpp b/benchmarks/xetla_kernel/python_main.cpp index 9720cb89b8..9bcc4fe98a 100644 --- a/benchmarks/xetla_kernel/python_main.cpp +++ b/benchmarks/xetla_kernel/python_main.cpp @@ -5,14 +5,16 @@ #include #include #include -#include #ifdef USE_IPEX +// `#include ` should be before `#include ` #include #else #include #endif +#include + sycl::queue get_current_sycl_queue() { // submit kernel c10::impl::VirtualGuardImpl impl(at::DeviceType::XPU); From b1c5467aa7360f189c2e4b1a3085f93bd2ae9bfa Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Sat, 7 Sep 2024 15:10:37 +0000 Subject: [PATCH 04/13] fix Signed-off-by: Anatoly Myachev --- .../triton_kernels_benchmark/flash_attention_fwd_benchmark.py | 2 +- benchmarks/triton_kernels_benchmark/fused_softmax.py | 2 +- benchmarks/triton_kernels_benchmark/gemm_benchmark.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index d409996f3f..f9ec425f4c 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -5,7 +5,7 @@ import triton.language as tl import triton_kernels_benchmark -from triton_kernels_benchmark import xetla_kernel # pylint: disable=no-name-in-module +import xetla_kernel benchmark_suit = triton_kernels_benchmark # triton.testing diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py index 4a229a6b9e..8a1bcbbd96 100644 --- a/benchmarks/triton_kernels_benchmark/fused_softmax.py +++ b/benchmarks/triton_kernels_benchmark/fused_softmax.py @@ -15,7 +15,7 @@ from triton.runtime import driver import triton_kernels_benchmark -from triton_kernels_benchmark import xetla_kernel # pylint: disable=no-name-in-module +import xetla_kernel benchmark_suit = triton_kernels_benchmark # triton.testing diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py index b749a5c79b..89ba8a8430 100644 --- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py @@ -14,7 +14,7 @@ import triton.language as tl import triton_kernels_benchmark as benchmark_suit -from triton_kernels_benchmark import xetla_kernel # pylint: disable=no-name-in-module +import xetla_kernel @triton.autotune( From d2614e208c0f08c5e8b6d6a78302e935deaf90b1 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Sun, 8 Sep 2024 13:12:17 +0000 Subject: [PATCH 05/13] remove remaining stuff Signed-off-by: Anatoly Myachev --- benchmarks/xetla_kernel/python_main.cpp | 9 +-------- 1 file changed, 1 
insertion(+), 8 deletions(-) diff --git a/benchmarks/xetla_kernel/python_main.cpp b/benchmarks/xetla_kernel/python_main.cpp index 9bcc4fe98a..2440e937eb 100644 --- a/benchmarks/xetla_kernel/python_main.cpp +++ b/benchmarks/xetla_kernel/python_main.cpp @@ -7,14 +7,12 @@ #include #ifdef USE_IPEX -// `#include ` should be before `#include ` +#include #include #else #include #endif -#include - sycl::queue get_current_sycl_queue() { // submit kernel c10::impl::VirtualGuardImpl impl(at::DeviceType::XPU); @@ -41,7 +39,6 @@ sycl::queue get_current_sycl_queue() { template at::Tensor softmax(const at::Tensor &input, const int64_t dim) { CHECK_INPUT(input); - RECORD_FUNCTION("xetla softmax", {input}); auto output = at::empty_like(input); @@ -58,7 +55,6 @@ at::Tensor bf16_gemm(const at::Tensor &a, const at::Tensor &b, CHECK_INPUT(b); CHECK_INPUT(c); CHECK_INPUT(acc); - RECORD_FUNCTION("xetla gemm", {a, b, c, acc}); auto queue = get_current_sycl_queue(); auto evt = gemm_run(a.data_ptr(), b.data_ptr(), c.data_ptr(), @@ -73,7 +69,6 @@ at::Tensor bf16_stream_k_gemm(const at::Tensor &a, const at::Tensor &b, CHECK_INPUT(b); CHECK_INPUT(c); CHECK_INPUT(acc); - RECORD_FUNCTION("xetla stream_k_gemm", {a, b, c, acc}); auto queue = get_current_sycl_queue(); auto evt = stream_k_gemm_run(a.data_ptr(), b.data_ptr(), c.data_ptr(), @@ -90,8 +85,6 @@ template Date: Sun, 8 Sep 2024 13:17:45 +0000 Subject: [PATCH 06/13] fix Signed-off-by: Anatoly Myachev --- benchmarks/xetla_kernel/python_main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/xetla_kernel/python_main.cpp b/benchmarks/xetla_kernel/python_main.cpp index 2440e937eb..25e508cb5c 100644 --- a/benchmarks/xetla_kernel/python_main.cpp +++ b/benchmarks/xetla_kernel/python_main.cpp @@ -7,8 +7,8 @@ #include #ifdef USE_IPEX -#include #include +#include #else #include #endif From dedd201d5e4d3e44c5fe186c690813c965e58180 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Sun, 8 Sep 2024 13:55:56 +0000 Subject: [PATCH 07/13] Revert changes for disabling IPEX Signed-off-by: Anatoly Myachev --- benchmarks/CMakeLists.txt | 8 +------- benchmarks/setup.py | 8 ++------ benchmarks/xetla_kernel/CMakeLists.txt | 7 ++----- benchmarks/xetla_kernel/python_main.cpp | 14 +------------- 4 files changed, 6 insertions(+), 31 deletions(-) diff --git a/benchmarks/CMakeLists.txt b/benchmarks/CMakeLists.txt index 7d1c59ea6f..03cb418132 100644 --- a/benchmarks/CMakeLists.txt +++ b/benchmarks/CMakeLists.txt @@ -4,8 +4,6 @@ set(CMAKE_CXX_STANDARD 20) project(TritonBenchmark) -option(USE_IPEX "Use IPEX" ON) - if(NOT WIN32) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") endif() @@ -13,11 +11,7 @@ endif() find_package(Python3 COMPONENTS Interpreter) find_package(Torch REQUIRED) find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib") - -if(USE_IPEX) - string(APPEND CMAKE_CXX_FLAGS " -DUSE_IPEX") - find_package(IPEX REQUIRED) -endif() +find_package(IPEX REQUIRED) # add the XeTLA kernel. 
diff --git a/benchmarks/setup.py b/benchmarks/setup.py index 03246a995a..08f76e21a7 100644 --- a/benchmarks/setup.py +++ b/benchmarks/setup.py @@ -8,11 +8,7 @@ from setuptools import setup import torch - -ipex_cmake_prefix_path = "" -if os.getenv("USE_IPEX", "1") == "1": - import intel_extension_for_pytorch - ipex_cmake_prefix_path = f";{intel_extension_for_pytorch.cmake_prefix_path}" +import intel_extension_for_pytorch class CMakeBuild(): @@ -47,7 +43,7 @@ def build_extension(self): "Ninja", # Ninja is much faster than make "-DCMAKE_MAKE_PROGRAM=" + ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path - f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}{ipex_cmake_prefix_path}", + f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path};{intel_extension_for_pytorch.cmake_prefix_path}", "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", "-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY=" + self.extdir, "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + self.extdir, diff --git a/benchmarks/xetla_kernel/CMakeLists.txt b/benchmarks/xetla_kernel/CMakeLists.txt index 64334dc518..02e5e79475 100644 --- a/benchmarks/xetla_kernel/CMakeLists.txt +++ b/benchmarks/xetla_kernel/CMakeLists.txt @@ -33,14 +33,11 @@ target_compile_options(xetla_kernel PRIVATE "-fsycl") target_compile_options(xetla_kernel PUBLIC "-DXETPP_NEW_XMAIN") target_link_options(xetla_kernel PRIVATE ${XETLA_KERNEL_FLAGS}) target_link_libraries(xetla_kernel PUBLIC ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY}) +target_link_libraries(xetla_kernel PUBLIC "${TORCH_IPEX_LIBRARIES}") target_include_directories(xetla_kernel PUBLIC "${PYTHON_INCLUDE_DIRS}") +target_include_directories(xetla_kernel PUBLIC "${TORCH_IPEX_INCLUDE_DIRS}") target_include_directories(xetla_kernel PUBLIC "${XeTLALibrary_INCLUDE_DIR}") -if(USE_IPEX) - target_link_libraries(xetla_kernel PUBLIC "${TORCH_IPEX_LIBRARIES}") - target_include_directories(xetla_kernel PUBLIC "${TORCH_IPEX_INCLUDE_DIRS}") -endif() - add_subdirectory(softmax) add_subdirectory(gemm) add_subdirectory(stream_k_gemm) diff --git a/benchmarks/xetla_kernel/python_main.cpp b/benchmarks/xetla_kernel/python_main.cpp index 25e508cb5c..193db6f142 100644 --- a/benchmarks/xetla_kernel/python_main.cpp +++ b/benchmarks/xetla_kernel/python_main.cpp @@ -5,27 +5,15 @@ #include #include #include - -#ifdef USE_IPEX #include #include -#else -#include -#endif sycl::queue get_current_sycl_queue() { // submit kernel c10::impl::VirtualGuardImpl impl(at::DeviceType::XPU); c10::Stream stream = impl.getStream(impl.getDevice()); -#ifdef USE_IPEX - auto queue = xpu::get_queue_from_stream(stream); -#else - auto xpu_stream = c10::xpu::XPUStream(stream); - auto queue = xpu_stream.queue(); -#endif - - return queue; + return xpu::get_queue_from_stream(stream); } #define CHECK_XPU(x) \ From c5aefdcddcb19fe5bbc97a81c8e407d010b12b70 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 9 Sep 2024 14:17:42 +0000 Subject: [PATCH 08/13] DEBUG: emulate 'sync_submitting' Signed-off-by: Anatoly Myachev --- python/triton/testing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/triton/testing.py b/python/triton/testing.py index 07827ad853..2c739fffd9 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -209,6 +209,8 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu x.grad = None # we clear the L2 cache before each run cache.zero_() + # emulate `sync_submitting` to check influence + di.synchronize() # record time of `fn` start_event[i].record() fn() From 
10151f6152215d43e42e14f051dd00730c4f42a8 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 9 Sep 2024 23:30:50 +0000 Subject: [PATCH 09/13] Revert "DEBUG: emulate 'sync_submitting'" This reverts commit c5aefdcddcb19fe5bbc97a81c8e407d010b12b70. --- python/triton/testing.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/triton/testing.py b/python/triton/testing.py index 2c739fffd9..07827ad853 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -209,8 +209,6 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu x.grad = None # we clear the L2 cache before each run cache.zero_() - # emulate `sync_submitting` to check influence - di.synchronize() # record time of `fn` start_event[i].record() fn() From 89c40df2c9c17d9f1b224b6cee2fc3c81f5d2c89 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 9 Sep 2024 23:33:23 +0000 Subject: [PATCH 10/13] Try to reduce overhead while using the elapsed_time profiling method Signed-off-by: Anatoly Myachev --- .../flash_attention_fwd_benchmark.py | 11 ++++- .../flash_attention/fmha_forward_v5.h | 41 ++++++++++++------- benchmarks/xetla_kernel/python_main.cpp | 22 ++++++++-- 3 files changed, 54 insertions(+), 20 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py index f9ec425f4c..f1a0d1b6ab 100644 --- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py +++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py @@ -235,7 +235,16 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider): elif provider == 'xetla': func = getattr(xetla_kernel, 'flash_attn') - xetla_fn = lambda: func(Z, H, D_HEAD, N_CTX, N_CTX) + out = torch.empty_like(q, device='xpu', dtype=dtype) + size_score = Z * H * N_CTX * N_CTX + size_attn_mask = Z * N_CTX * N_CTX + dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8) + bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype) + size_ml = Z * H * N_CTX + m = torch.empty((size_ml, ), device='xpu', dtype=torch.float) + l = torch.empty((size_ml, ), device='xpu', dtype=torch.float) + + xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX) _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles, fast_flush=False) diff --git a/benchmarks/xetla_kernel/flash_attention/fmha_forward_v5.h b/benchmarks/xetla_kernel/flash_attention/fmha_forward_v5.h index f05bc83f25..cef509c03b 100644 --- a/benchmarks/xetla_kernel/flash_attention/fmha_forward_v5.h +++ b/benchmarks/xetla_kernel/flash_attention/fmha_forward_v5.h @@ -620,7 +620,9 @@ class FmhaForwardKernel; // The launcher of fmha forward kernel template -sycl::event fmha_forward_impl(sycl::queue &q, uint32_t num_batches, +sycl::event fmha_forward_impl(sycl::queue &q, void *_q, void *_k, void *_v, + void *_out, void *_dropout_mask, void *_bias, + void *_m, void *_l, uint32_t num_batches, uint32_t num_heads, uint32_t head_size, uint32_t num_queries, uint32_t num_keys, uint64_t seed = 0, uint64_t offset = 123) { @@ -642,14 +644,23 @@ sycl::event fmha_forward_impl(sycl::queue &q, uint32_t num_batches, uint32_t size_ml = shape.get_ml_size(); // forward - T *query = sycl::malloc_shared(size_query, q); - T *key = sycl::malloc_shared(size_key, q); - T *value = sycl::malloc_shared(size_key, q); - T *bias = sycl::malloc_shared(size_attn_mask, q); - uint8_t *dropout_mask = 
sycl::malloc_shared(size_score, q); - T *out = sycl::malloc_shared(size_query, q); - float *m = sycl::malloc_shared(size_ml, q); - float *l = sycl::malloc_shared(size_ml, q); + // T *query = sycl::malloc_shared(size_query, q); + // T *key = sycl::malloc_shared(size_key, q); + // T *value = sycl::malloc_shared(size_key, q); + T *query = static_cast(_q); + T *key = static_cast(_k); + T *value = static_cast(_v); + + // T *bias = sycl::malloc_shared(size_attn_mask, q); + T *bias = static_cast(_bias); + // uint8_t *dropout_mask = sycl::malloc_shared(size_score, q); + uint8_t *dropout_mask = static_cast(_dropout_mask); + // T *out = sycl::malloc_shared(size_query, q); + T *out = static_cast(_out); + // float *m = sycl::malloc_shared(size_ml, q); + float *m = static_cast(_m); + // float *l = sycl::malloc_shared(size_ml, q); + float *l = static_cast(_l); // fmha forward kernel using fmha_forward_op_t = @@ -676,12 +687,12 @@ sycl::event fmha_forward_impl(sycl::queue &q, uint32_t num_batches, fmha_fwd_op(ei, args); }); }); - sycl::free(query, q); - sycl::free(key, q); - sycl::free(value, q); - sycl::free(bias, q); - sycl::free(dropout_mask, q); - sycl::free(out, q); + // sycl::free(query, q); + // sycl::free(key, q); + // sycl::free(value, q); + // sycl::free(bias, q); + // sycl::free(dropout_mask, q); + // sycl::free(out, q); return event; } diff --git a/benchmarks/xetla_kernel/python_main.cpp b/benchmarks/xetla_kernel/python_main.cpp index 193db6f142..2e3568fc64 100644 --- a/benchmarks/xetla_kernel/python_main.cpp +++ b/benchmarks/xetla_kernel/python_main.cpp @@ -66,13 +66,27 @@ at::Tensor bf16_stream_k_gemm(const at::Tensor &a, const at::Tensor &b, #define CALL_IMPL_ATTENTION_FUNC(P) \ fmha::fmha_forward_impl( \ - queue, num_batches, num_heads, head_size, num_queries, num_keys) + queue, q.data_ptr(), k.data_ptr(), v.data_ptr(), out.data_ptr(), \ + dropout_mask.data_ptr(), bias.data_ptr(), m.data_ptr(), l.data_ptr(), \ + num_batches, num_heads, head_size, num_queries, num_keys) template -void flash_attn(const int64_t num_batches, const int64_t num_heads, - const int64_t head_size, const int64_t num_queries, - const int64_t num_keys) { +void flash_attn(const at::Tensor &q, const at::Tensor &k, const at::Tensor &v, + const at::Tensor &out, const at::Tensor &dropout_mask, + const at::Tensor &bias, const at::Tensor &m, + const at::Tensor &l, const int64_t num_batches, + const int64_t num_heads, const int64_t head_size, + const int64_t num_queries, const int64_t num_keys) { + + CHECK_INPUT(q); + CHECK_INPUT(k); + CHECK_INPUT(v); + CHECK_INPUT(out); + CHECK_INPUT(dropout_mask); + CHECK_INPUT(bias); + CHECK_INPUT(m); + CHECK_INPUT(l); auto queue = get_current_sycl_queue(); From 9bbe8a97b7233d6aac1060f76a002823535725d3 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 9 Sep 2024 23:58:48 +0000 Subject: [PATCH 11/13] update softmax Signed-off-by: Anatoly Myachev --- benchmarks/triton_kernels_benchmark/fused_softmax.py | 10 +++++----- benchmarks/xetla_kernel/python_main.cpp | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py index 8a1bcbbd96..0fd95f6f2f 100644 --- a/benchmarks/triton_kernels_benchmark/fused_softmax.py +++ b/benchmarks/triton_kernels_benchmark/fused_softmax.py @@ -88,7 +88,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n MAX_WORK_GROUP_SIZE = properties["max_work_group_size"] -def softmax(x): +def softmax(x, y): n_rows, 
n_cols = x.shape # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x` @@ -96,8 +96,6 @@ def softmax(x): BLOCK_SIZE_Y = MAX_WORK_GROUP_SIZE // BLOCK_SIZE_X BLOCK_SIZE_Y = BLOCK_SIZE_Y if BLOCK_SIZE_Y > 0 else 1 - # Allocate output - y = torch.empty_like(x) # Create a number of persistent programs. softmax_kernel[(n_rows // BLOCK_SIZE_Y, )](y, x, x.stride(0), y.stride(0), n_cols, BLOCK_SIZE_X=BLOCK_SIZE_X, BLOCK_SIZE_Y=BLOCK_SIZE_Y) @@ -133,7 +131,8 @@ def benchmark(M, N, provider): _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, warmup=10, rep=10) if provider == "triton": - triton_fn = lambda: softmax(x) + out = torch.empty_like(x, device="xpu") + triton_fn = lambda: softmax(x, out) torch_fn = lambda: torch.softmax(x, axis=-1) benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch") _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10) @@ -145,7 +144,8 @@ def benchmark(M, N, provider): elif provider == "xetla": name = f"softmax_shape_{M}_{N}" func = getattr(xetla_kernel, name) - xetla_fn = lambda: func(x, 0) + out = torch.empty_like(x, device="xpu") + xetla_fn = lambda: func(x, out, 0) torch_fn = lambda: torch.softmax(x, axis=-1) # benchmark_suit.assert_close(xetla_fn(), torch_fn(), err_msg="xetla to torch") _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10) diff --git a/benchmarks/xetla_kernel/python_main.cpp b/benchmarks/xetla_kernel/python_main.cpp index 2e3568fc64..c30c455fc2 100644 --- a/benchmarks/xetla_kernel/python_main.cpp +++ b/benchmarks/xetla_kernel/python_main.cpp @@ -25,10 +25,10 @@ sycl::queue get_current_sycl_queue() { CHECK_CONTIGUOUS(x) template -at::Tensor softmax(const at::Tensor &input, const int64_t dim) { +at::Tensor softmax(const at::Tensor &input, const at::Tensor &output, + const int64_t dim) { CHECK_INPUT(input); - - auto output = at::empty_like(input); + CHECK_INPUT(output); auto queue = get_current_sycl_queue(); auto evt = softmax_forward(input.data_ptr(), output.data_ptr(), queue); From f6726e4bd65226182a23f92cc735ea0cea29974a Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Thu, 19 Sep 2024 19:39:01 +0000 Subject: [PATCH 12/13] fix after merge Signed-off-by: Anatoly Myachev --- benchmarks/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/setup.py b/benchmarks/setup.py index b88f875fc1..46b30c37ab 100644 --- a/benchmarks/setup.py +++ b/benchmarks/setup.py @@ -42,7 +42,7 @@ def build_extension(self): "Ninja", # Ninja is much faster than make "-DCMAKE_MAKE_PROGRAM=" + ninja_dir, # Pass explicit path to ninja otherwise cmake may cache a temporary path - f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}{ipex_cmake_prefix_path}", + f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}", "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", "-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY=" + self.extdir, "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + self.extdir, From 3754c5e9e99386d593f4dc8e6689a7b39dbe83c7 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Thu, 19 Sep 2024 20:53:40 +0000 Subject: [PATCH 13/13] remove USE_IPEX Signed-off-by: Anatoly Myachev --- .github/workflows/triton-benchmarks.yml | 1 - benchmarks/CMakeLists.txt | 8 - .../float_conversion/float_conversion.py | 5 - .../benchmark_driver.py | 434 ------------------ benchmarks/xetla_kernel/CMakeLists.txt | 11 +- benchmarks/xetla_kernel/python_main.cpp 
| 12 +- scripts/test-triton.sh | 2 +- 7 files changed, 3 insertions(+), 470 deletions(-) delete mode 100644 benchmarks/triton_kernels_benchmark/benchmark_driver.py diff --git a/.github/workflows/triton-benchmarks.yml b/.github/workflows/triton-benchmarks.yml index 9be17f897e..4ff2c56d99 100644 --- a/.github/workflows/triton-benchmarks.yml +++ b/.github/workflows/triton-benchmarks.yml @@ -22,7 +22,6 @@ permissions: read-all env: PYTHON_VERSION: "3.10" - USE_IPEX: ${{ github.event_name == 'schedule' && '1' || inputs.install_ipex && '1' || '0' }} jobs: build: diff --git a/benchmarks/CMakeLists.txt b/benchmarks/CMakeLists.txt index 7d1c59ea6f..0a4977aa21 100644 --- a/benchmarks/CMakeLists.txt +++ b/benchmarks/CMakeLists.txt @@ -4,8 +4,6 @@ set(CMAKE_CXX_STANDARD 20) project(TritonBenchmark) -option(USE_IPEX "Use IPEX" ON) - if(NOT WIN32) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") endif() @@ -14,11 +12,5 @@ find_package(Python3 COMPONENTS Interpreter) find_package(Torch REQUIRED) find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib") -if(USE_IPEX) - string(APPEND CMAKE_CXX_FLAGS " -DUSE_IPEX") - find_package(IPEX REQUIRED) -endif() - - # add the XeTLA kernel. add_subdirectory(xetla_kernel) diff --git a/benchmarks/micro_benchmarks/conversion/float_conversion/float_conversion.py b/benchmarks/micro_benchmarks/conversion/float_conversion/float_conversion.py index 03fd69a19f..84032bd9b6 100644 --- a/benchmarks/micro_benchmarks/conversion/float_conversion/float_conversion.py +++ b/benchmarks/micro_benchmarks/conversion/float_conversion/float_conversion.py @@ -1,12 +1,7 @@ -import os - import torch import triton import triton.language as tl -if os.getenv('USE_IPEX', '1') == '1': - import intel_extension_for_pytorch # type: ignore # noqa: F401 - @triton.jit def float_trunc_kernel( diff --git a/benchmarks/triton_kernels_benchmark/benchmark_driver.py b/benchmarks/triton_kernels_benchmark/benchmark_driver.py deleted file mode 100644 index aa1cbb41ae..0000000000 --- a/benchmarks/triton_kernels_benchmark/benchmark_driver.py +++ /dev/null @@ -1,434 +0,0 @@ -import os -import hashlib -import importlib.util -import tempfile -from pathlib import Path - -from triton.backends.compiler import GPUTarget -from triton.backends.driver import DriverBase -from triton.runtime.cache import get_cache_manager -from triton.runtime.build import _build, quiet - -import torch -import intel_extension_for_pytorch - -_dirname = os.getenv("ZE_PATH", default="/usr/local") - -include_dir = [ - os.path.join(_dirname, "include"), - os.path.join(torch.utils.cmake_prefix_path, "../../include"), - os.path.join(torch.utils.cmake_prefix_path, "../../include/torch/csrc/api/include"), - os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../include") -] - -oneapi_root = os.getenv("ONEAPI_ROOT") -if oneapi_root: - include_dir += [ - os.path.join(oneapi_root, "compiler/latest/include"), - os.path.join(oneapi_root, "compiler/latest/include/sycl") - ] - -library_dir = [ - os.path.join(_dirname, "lib"), - os.path.join(torch.utils.cmake_prefix_path, "../../lib"), - os.path.join(intel_extension_for_pytorch.cmake_prefix_path, "../../lib") -] -libraries = ["ze_loader", "sycl", "torch", "intel-ext-pt-gpu"] - - -def compile_module_from_src(src, name): - key = hashlib.sha256(src.encode("utf-8")).hexdigest() - cache = get_cache_manager(key) - cache_path = cache.get_file(f"{name}.so") - if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, 
"main.cpp") - with open(src_path, "w", encoding="utf-8") as f: - f.write(src) - with quiet(): - so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries) - with open(so, "rb") as f: - cache_path = cache.put(f.read(), f"{name}.so", binary=True) - spec = importlib.util.spec_from_file_location(name, cache_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - return mod - - -# ------------------------ -# Utils -# ------------------------ - - -class XPUUtils: - - def __new__(cls): - if not hasattr(cls, "instance"): - cls.instance = super(XPUUtils, cls).__new__(cls) - return cls.instance - - def __init__(self): - dirname = os.path.dirname(os.path.realpath(__file__)) - mod = compile_module_from_src( - Path(os.path.join(dirname, "driver.c")).read_text(encoding="utf-8"), "spirv_utils") - self.load_binary = mod.load_binary - self.get_device_properties = mod.get_device_properties - self.context = mod.init_context(self.get_sycl_queue()) - self.device_count = mod.init_devices(self.get_sycl_queue()) - self.current_device = 0 if self.device_count[0] > 0 else -1 - - def get_current_device(self): - return self.current_device - - def get_sycl_queue(self): - return torch.xpu.current_stream().sycl_queue - - -# ------------------------ -# Launcher -# ------------------------ - - -def ty_to_cpp(ty): - if ty[0] == "*": - return "void*" - return { - "i1": "int32_t", - "i8": "int8_t", - "i16": "int16_t", - "i32": "int32_t", - "i64": "int64_t", - "u1": "uint32_t", - "u8": "uint8_t", - "u16": "uint16_t", - "u32": "uint32_t", - "u64": "uint64_t", - "fp16": "float", - "bf16": "float", - "fp32": "float", - "f32": "float", - "fp64": "double", - }[ty] - - -def make_launcher(constants, signature, ids): # pylint: disable=unused-argument - # Record the end of regular arguments; - # subsequent arguments are architecture-specific descriptors. 
- arg_decls = ", ".join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) - - def _extracted_type(ty): - if ty[0] == "*": - return "PyObject*" - return ty_to_cpp(ty) - - def format_of(ty): - return { - "PyObject*": "O", - "float": "f", - "double": "d", - "long": "l", - "int8_t": "b", - "int16_t": "h", - "int32_t": "i", - "int64_t": "l", - "uint8_t": "B", - "uint16_t": "H", - "uint32_t": "I", - "uint64_t": "K", - }[ty] - - args_format = "".join([format_of(_extracted_type(ty)) for ty in signature.values()]) - fmt = "iiiOOOOOO" + args_format - args_list = ", " + ", ".join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else "" - - # generate glue code - src = f""" - #include - #include - #include - #include - #include - #include - #include - #include - - #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION - #include - #include - #include - - static inline void gpuAssert(ze_result_t code, const char *file, int line) - {{ - if (code != ZE_RESULT_SUCCESS) - {{ - const char* prefix = "Triton Error [ZE]: "; - std::string str = std::to_string(code); - char err[1024] = {{0}}; - strcat(err, prefix); - strcat(err, str.c_str()); - PyErr_SetString(PyExc_RuntimeError, err); - }} - }} - - #define ZE_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} - - typedef struct _DevicePtrInfo {{ - void* dev_ptr; - bool valid; - }} DevicePtrInfo; - - static inline void checkDevicePointer(DevicePtrInfo *ptr_info, int idx, const sycl::queue &queue) {{ - if (!ptr_info->dev_ptr || !ptr_info->valid) {{ - return; - }} - auto context = queue.get_context(); - auto handle = sycl::get_native(context); - ze_memory_allocation_properties_t prop; - prop.stype = ZE_STRUCTURE_TYPE_MEMORY_ALLOCATION_PROPERTIES; - prop.pNext = nullptr; - ze_device_handle_t device; - auto res = zeMemGetAllocProperties((ze_context_handle_t)handle, ptr_info->dev_ptr, &prop, &device); - if (res != ZE_RESULT_SUCCESS) {{ - PyErr_Format(PyExc_ValueError, - "Cannot get memory properties for pointer argument (at %d, err=%d)", idx, res); - ptr_info->valid = false; - }} else if (prop.type != ZE_MEMORY_TYPE_DEVICE) {{ - PyErr_Format(PyExc_ValueError, - "Pointer argument (at %d) doesn't reference XPU device memory (cpu tensor?)", idx); - ptr_info->valid = false; - }} - }} - - static inline DevicePtrInfo getPointer(PyObject *obj, int idx, const sycl::queue &queue) {{ - DevicePtrInfo ptr_info; - ptr_info.dev_ptr = 0; - ptr_info.valid = true; - if (PyLong_Check(obj)) {{ - ptr_info.dev_ptr = PyLong_AsVoidPtr(obj); - checkDevicePointer(&ptr_info, idx, queue); - return ptr_info; - }} - if (obj == Py_None) {{ - // valid nullptr - return ptr_info; - }} - PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); - if(ptr){{ - PyObject *empty_tuple = PyTuple_New(0); - PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(ptr); - if (!PyLong_Check(ret)) {{ - PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); - ptr_info.valid = false; - return ptr_info; - }} - ptr_info.dev_ptr = PyLong_AsVoidPtr(ret); - if(!ptr_info.dev_ptr) {{ - return ptr_info; - }} - checkDevicePointer(&ptr_info, idx, queue); - Py_DECREF(ret); // Thanks ChatGPT! 
- return ptr_info; - }} - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); - ptr_info.valid = false; - return ptr_info; - }} -// start sycl - static void set_scalar_arg( - sycl::handler& cgh, - int index, - size_t size, - const void* value) {{ - switch (size) {{ - case sizeof(uint8_t): - cgh.set_arg(index, *static_cast(value)); - break; - case sizeof(uint16_t): - cgh.set_arg(index, *static_cast(value)); - break; - case sizeof(uint32_t): - cgh.set_arg(index, *static_cast(value)); - break; - case sizeof(uint64_t): - cgh.set_arg(index, *static_cast(value)); - break; - default: - assert(false && "wrong scalar size in sycl gen."); - }} - }} - static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr {", " + arg_decls if len(arg_decls) > 0 else ""}) {{ - - std::string kernel_name = kernel_ptr.get_info(); - RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {{}}); - void *params[] = {{ {", ".join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; - uint32_t num_params = sizeof(params)/sizeof(params[0]); - uint32_t expected_num_params = kernel_ptr.get_info(); - size_t global_range_x = gridX*threads_per_warp*num_warps; - size_t global_range_y = gridY; - size_t global_range_z = gridZ; - size_t local_range_x = num_warps*threads_per_warp; - size_t local_range_y = 1; - size_t local_range_z = 1; - sycl::range<3> global_range(global_range_z, global_range_y, global_range_x); - sycl::range<3> local_range(local_range_z, local_range_y, local_range_x); - sycl::nd_range<3> parallel_work_size(global_range, local_range); - if (shared_memory) {{ - expected_num_params -= 1; - }} - assert(num_params == expected_num_params && "number of kernel param not matched"); - // Submit the imported kernel. 
- auto cgf = [&](sycl::handler &cgh) {{ - {" ".join(f"set_scalar_arg(cgh, {idx}, sizeof({ty_to_cpp(item)}), params[{idx}]);" for idx, item in enumerate([signature[i] for i in signature if i not in constants]))} - if (shared_memory) {{ - using share_mem_t = sycl::local_accessor; - share_mem_t local_buffer = share_mem_t(shared_memory, cgh); - cgh.set_arg(num_params, local_buffer); - cgh.parallel_for(parallel_work_size, kernel_ptr); - }} else {{ - cgh.parallel_for(parallel_work_size, kernel_ptr); - }} - }}; - auto event = stream.submit(cgf); - xpu::profiler_record(kernel_name, event); - }} -// end sycl - static PyObject* launch(PyObject* self, PyObject* args) {{ - - int gridX, gridY, gridZ; - PyObject *launch_enter_hook = NULL; - PyObject *launch_exit_hook = NULL; - PyObject *kernel_metadata = NULL; - PyObject *launch_metadata = NULL; - PyObject *py_obj_stream; - PyObject *py_kernel; - - {" ".join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} - if(!PyArg_ParseTuple(args, \"{fmt}\", &gridX, &gridY, &gridZ, &py_obj_stream, &py_kernel, - &kernel_metadata, &launch_metadata, - &launch_enter_hook, &launch_exit_hook {args_list})) {{ - return NULL; - }} - - // extract kernel metadata - int num_warps = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "num_warps")); - int num_ctas = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "num_ctas")); - int shared_memory = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "shared")); - int threads_per_warp = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "threads_per_warp")); - - // extract cluster dims - PyObject *clusterDim = PyObject_GetAttrString(kernel_metadata, "cluster_dims"); - if (!PyTuple_Check(kernel_metadata)) {{ - PyErr_SetString(PyExc_TypeError, "kernel_metadata.cluster_dims must be a tuple"); - return NULL; - }} - int clusterDimX = PyLong_AsLong(PyTuple_GetItem(clusterDim, 0)); - int clusterDimY = PyLong_AsLong(PyTuple_GetItem(clusterDim, 1)); - int clusterDimZ = PyLong_AsLong(PyTuple_GetItem(clusterDim, 2)); - // extract launch metadata - if (launch_enter_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_enter_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - }} - - void * pStream = PyLong_AsVoidPtr(py_obj_stream); - //error check - if(pStream == nullptr || py_kernel == nullptr) return NULL; - - sycl::queue stream = *(static_cast(pStream)); - sycl::kernel* kernel_ptr = reinterpret_cast(PyCapsule_GetPointer(py_kernel, "kernel")); - if(kernel_ptr == nullptr) return NULL; - sycl::kernel kernel = *kernel_ptr; - - {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; - sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {"," + ", ".join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""}); - - if(launch_exit_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_exit_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - }} - if (PyErr_Occurred()) {{ - return NULL; - }} - - // return None - Py_INCREF(Py_None); - return Py_None; - }} - - static PyMethodDef ModuleMethods[] = {{ - {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, - {{NULL, NULL, 0, NULL}} // sentinel - }}; - - static struct 
PyModuleDef ModuleDef = {{ - PyModuleDef_HEAD_INIT, - \"__triton_launcher\", - NULL, //documentation - -1, //size - ModuleMethods - }}; - - PyMODINIT_FUNC PyInit___triton_launcher(void) {{ - PyObject *m = PyModule_Create(&ModuleDef); - if(m == NULL) {{ - return NULL; - }} - PyModule_AddFunctions(m, ModuleMethods); - return m; - }} - """ - return src - - -class XPULauncher: - - def __init__(self, src, metadata): # pylint: disable=unused-argument - ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} - constants = src.constants if hasattr(src, "constants") else {} - cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i - constants = {cst_key(key): value for key, value in constants.items()} - signature = {cst_key(key): value for key, value in src.signature.items()} - src = make_launcher(constants, signature, ids) - mod = compile_module_from_src(src, "__triton_launcher") - self.launch = mod.launch - - def __call__(self, *args, **kwargs): - self.launch(*args, **kwargs) - - -class XPUDriver(DriverBase): - - def __init__(self): - self.launcher_cls = XPULauncher - - def __getattr__(self, name): - # Lazily initialize utils to avoid unnecessary XPU runtime invocations. - # See https://github.com/intel/intel-xpu-backend-for-triton/issues/624 - if name == "utils": - self.utils = XPUUtils() # pylint: disable=attribute-defined-outside-init - return self.utils - raise AttributeError - - def get_current_device(self): - return self.utils.get_current_device() - - def get_current_stream(self, device): # pylint: disable=unused-argument - return torch.xpu.current_stream().sycl_queue - - def get_current_target(self): - device = self.get_current_device() - dev_property = torch.xpu.get_device_capability(device) - warp_size = 32 - return GPUTarget("xpu", dev_property, warp_size) - - @staticmethod - def is_active(): - return torch.xpu.is_available() diff --git a/benchmarks/xetla_kernel/CMakeLists.txt b/benchmarks/xetla_kernel/CMakeLists.txt index 45c8f8faa3..61131f04b7 100644 --- a/benchmarks/xetla_kernel/CMakeLists.txt +++ b/benchmarks/xetla_kernel/CMakeLists.txt @@ -29,22 +29,13 @@ endif() add_library(xetla_kernel SHARED python_main.cpp) set_target_properties(xetla_kernel PROPERTIES PREFIX "") target_compile_options(xetla_kernel PRIVATE "-fPIC") -if(USE_IPEX) - target_compile_options(xetla_kernel PRIVATE "-fsycl") -else() - target_compile_options(xetla_kernel PRIVATE "-fsycl" "-fpreview-breaking-changes") -endif() +target_compile_options(xetla_kernel PRIVATE "-fsycl" "-fpreview-breaking-changes") target_compile_options(xetla_kernel PUBLIC "-DXETPP_NEW_XMAIN") target_link_options(xetla_kernel PRIVATE ${XETLA_KERNEL_FLAGS}) target_link_libraries(xetla_kernel PUBLIC ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY}) target_include_directories(xetla_kernel PUBLIC "${PYTHON_INCLUDE_DIRS}") target_include_directories(xetla_kernel PUBLIC "${XeTLALibrary_INCLUDE_DIR}") -if(USE_IPEX) - target_link_libraries(xetla_kernel PUBLIC "${TORCH_IPEX_LIBRARIES}") - target_include_directories(xetla_kernel PUBLIC "${TORCH_IPEX_INCLUDE_DIRS}") -endif() - add_subdirectory(softmax) add_subdirectory(gemm) add_subdirectory(stream_k_gemm) diff --git a/benchmarks/xetla_kernel/python_main.cpp b/benchmarks/xetla_kernel/python_main.cpp index 1af3840d0e..4b465d8619 100644 --- a/benchmarks/xetla_kernel/python_main.cpp +++ b/benchmarks/xetla_kernel/python_main.cpp @@ -4,27 +4,17 @@ #include "stream_k_gemm/stream_k_gemm.h" #include #include +#include #include #include -#ifdef USE_IPEX -#include -#else 
-#include -#endif - sycl::queue get_current_sycl_queue() { // submit kernel c10::impl::VirtualGuardImpl impl(at::DeviceType::XPU); c10::Stream stream = impl.getStream(impl.getDevice()); -#ifdef USE_IPEX - auto queue = xpu::get_queue_from_stream(stream); -#else auto xpu_stream = c10::xpu::XPUStream(stream); auto queue = xpu_stream.queue(); -#endif - return queue; } diff --git a/scripts/test-triton.sh b/scripts/test-triton.sh index d7659935ee..8f63141dfb 100755 --- a/scripts/test-triton.sh +++ b/scripts/test-triton.sh @@ -235,7 +235,7 @@ run_microbench_tests() { echo "****************************************************" echo "***** Running Triton Micro Benchmark tests *****" echo "****************************************************" - USE_IPEX=0 python $TRITON_PROJ/benchmarks/micro_benchmarks/run_benchmarks.py + python $TRITON_PROJ/benchmarks/micro_benchmarks/run_benchmarks.py } run_benchmark_softmax() {
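As an illustration of the measurement path the series converges on once the legacy-profiler hook and IPEX are removed, here is a minimal sketch of elapsed_time-based timing on XPU, mirroring the start_event/end_event pattern visible in the python/triton/testing.py hunks above. It assumes torch.xpu.Event supports the same enable_timing/record/elapsed_time API as torch.cuda.Event; time_fn_ms and n_iters are illustrative names, not part of these patches.

import torch

def time_fn_ms(fn, n_iters=10):
    # warm up once so one-time compilation and caching are not measured
    fn()
    torch.xpu.synchronize()
    starts = [torch.xpu.Event(enable_timing=True) for _ in range(n_iters)]
    ends = [torch.xpu.Event(enable_timing=True) for _ in range(n_iters)]
    for i in range(n_iters):
        starts[i].record()
        fn()
        ends[i].record()
    # wait for all submitted work before reading the event timings
    torch.xpu.synchronize()
    return [s.elapsed_time(e) for s, e in zip(starts, ends)]

Timing with device events rather than a profiler hook keeps per-launch host overhead small, which lines up with why the flash-attention and softmax patches above move tensor allocation out of the benchmarked callables.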