diff --git a/op_builder/xpu/post_process_cpu_adam.sh b/op_builder/xpu/post_process_cpu_adam.sh
new file mode 100644
index 000000000000..96e1e74a67c9
--- /dev/null
+++ b/op_builder/xpu/post_process_cpu_adam.sh
@@ -0,0 +1,6 @@
+# Comment out (prefix with "// ") the matched #include text in every context.h; meant for the torch CUDA headers. NOTE(review): the pattern shows no header name here - confirm.
+find ./deepspeed/third-party/csrc -name "context.h" -exec sed -Ei "s:#include :// &:g" {} +
+
+# Map the cuBLAS transpose flags (CUBLAS_OP_T / CUBLAS_OP_N) to their oneMKL oneapi::mkl::transpose equivalents.
+find ./deepspeed/third-party/csrc -name "context.h" -exec sed -i "s/CUBLAS_OP_T/oneapi::mkl::transpose::trans/g" {} +
+find ./deepspeed/third-party/csrc -name "context.h" -exec sed -i "s/CUBLAS_OP_N/oneapi::mkl::transpose::nontrans/g" {} +
diff --git a/op_builder/xpu/pre_process_cpu_adam.sh b/op_builder/xpu/pre_process_cpu_adam.sh
new file mode 100644
index 000000000000..eadb595230e4
--- /dev/null
+++ b/op_builder/xpu/pre_process_cpu_adam.sh
@@ -0,0 +1,27 @@
+# Patch context.h so it defines at::cuda::getCurrentCUDAStream and at::cuda::getStreamFromPool, both returning dpct::queue_ptr.
+patch ./build/csrc/includes/context.h << 'DIFF___'
+@@ -16,6 +16,23 @@
+ #include 
+ 
+ #include "gemm_test.h"
++#include 
++namespace at {
++  namespace cuda {
++    inline dpct::queue_ptr getCurrentCUDAStream() {  // inline: this definition lives in a header
++      auto device_type = c10::DeviceType::XPU;
++      c10::impl::VirtualGuardImpl impl(device_type);
++      c10::Stream c10_stream = impl.getStream(c10::Device(device_type));
++      auto& queue = xpu::get_queue_from_stream(c10_stream);
++      return &queue;
++    }
++
++    inline dpct::queue_ptr getStreamFromPool() {  // inline: this definition lives in a header
++      // Not implemented for XPU; always yields nullptr.
++      return nullptr;
++    }
++  }
++}
+ 
+ #define WARP_SIZE 32
+ 
+DIFF___