Skip to content

Commit

Permalink
all xpu OpBuilder done, need more test
Browse files Browse the repository at this point in the history
  • Loading branch information
baodii committed Nov 5, 2023
1 parent e207e8d commit 9f2f591
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 30 deletions.
3 changes: 0 additions & 3 deletions op_builder/xpu/cpu_adagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ def absolute_name(self):
return f'deepspeed.ops.adagrad.{self.NAME}_op'

def sources(self):
if self.build_for_cpu:
return ['csrc/adagrad/cpu_adagrad.cpp']

return ['csrc/adagrad/cpu_adagrad.cpp', 'csrc/common/custom_cuda_kernel.cu']

def include_paths(self):
Expand Down
28 changes: 28 additions & 0 deletions op_builder/xpu/post_process_cpu_adam.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,31 @@ find ./deepspeed/third-party/csrc -name "context.h" -exec sed -Ei "s:#include <A
# fix cublas transpose flag to mkl's
find ./deepspeed/third-party/csrc -name "context.h" -exec sed -i "s/CUBLAS_OP_T/oneapi::mkl::transpose::trans/g" {} +
find ./deepspeed/third-party/csrc -name "context.h" -exec sed -i "s/CUBLAS_OP_N/oneapi::mkl::transpose::nontrans/g" {} +

# add at::cuda::getCurrentCUDAStream and at::cuda::getStreamFromPool
patch ./deepspeed/third-party/csrc/includes/context.h << 'DIFF___'
@@ -16,6 +16,23 @@
#include <dpct/rng_utils.hpp>
#include "gemm_test.h"
+#include <ipex.h>
+namespace at {
+ namespace cuda {
+ dpct::queue_ptr getCurrentCUDAStream() {
+ auto device_type = c10::DeviceType::XPU;
+ c10::impl::VirtualGuardImpl impl(device_type);
+ c10::Stream c10_stream = impl.getStream(c10::Device(device_type));
+ auto& queue = xpu::get_queue_from_stream(c10_stream);
+ return &queue;
+ }
+
+ dpct::queue_ptr getStreamFromPool() {
+ // not implemented
+ return nullptr;
+ }
+ }
+}
#define WARP_SIZE 32
DIFF___
27 changes: 0 additions & 27 deletions op_builder/xpu/pre_process_cpu_adam.sh

This file was deleted.

0 comments on commit 9f2f591

Please sign in to comment.