add shared dequant() cuda placeholder
mobicham committed Mar 18, 2024
1 parent 4a95c9e commit 608f7c2
Showing 1 changed file with 58 additions and 3 deletions.
hqq/kernels/hqq_aten_cuda_kernel.cu: 58 additions & 3 deletions
@@ -8,8 +8,8 @@
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}
-#define BLOCK_SIZE 256
-#define SHARED_SIZE 512 //2xBLOCK_SIZE
+#define BLOCK_SIZE 256 //~256
+#define SHARED_SIZE 512 //~512

//Custom dispatcher to support Float, Half and Bfloat16, since the one in ATen doesn't support bfloat16: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Dispatch.h#L248
#define AT_DISPATCHER_CASE(...) \
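// Sketch (assumption, not part of this commit): the full definition above is collapsed
// in this diff view. A minimal version of the pattern, built from ATen's
// AT_DISPATCH_SWITCH / AT_DISPATCH_CASE macros (see the Dispatch.h link above), and
// using AT_DISPATCHER as a hypothetical name for the top-level macro, could look like:
//
//   #define AT_DISPATCHER_CASE(...)                               \
//     AT_DISPATCH_CASE(at::ScalarType::Float,    __VA_ARGS__)     \
//     AT_DISPATCH_CASE(at::ScalarType::Half,     __VA_ARGS__)     \
//     AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
//
//   #define AT_DISPATCHER(TYPE, NAME, ...)                        \
//     AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCHER_CASE(__VA_ARGS__))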
@@ -198,6 +198,34 @@ __global__ void dequantize_2bit_u8_kernel(unsigned char* Wq_packed, scalar_t* sc
W_r[i + n*3] = (scalar_t((Wq_packed[i] & 0x03)) - zero[j])*scale[j]; //4th chunk
}
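For reference, the masks above determine the packing layout: each packed byte holds four 2-bit values, which are expanded into four output chunks of n elements each. A hypothetical host-side packing helper that mirrors those masks (illustration only, not part of this commit):

// Mirrors the 2-bit unpack masks above; hypothetical helper, not in this commit.
unsigned char pack_2bit_u8(const unsigned char q[4]) {
    return (unsigned char)(((q[0] & 0x03) << 6) |  // chunk 0 -> bits 7-6
                           ((q[1] & 0x03) << 4) |  // chunk 1 -> bits 5-4
                           ((q[2] & 0x03) << 2) |  // chunk 2 -> bits 3-2
                            (q[3] & 0x03));        // chunk 3 -> bits 1-0
}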


// //Shared-memory variant (placeholder)
// template <typename scalar_t>
// __global__ void dequantize_2bit_u8_kernel(unsigned char* Wq_packed, scalar_t* scale, scalar_t* zero, scalar_t* W_r, int h, int w) {
// int i = blockIdx.x*blockDim.x + threadIdx.x;
// int n = h*w;
// int s = threadIdx.x;

// if(i>=n) return;

// __shared__ unsigned char shared[BLOCK_SIZE];
// __shared__ scalar_t shared_meta[BLOCK_SIZE][2];

// int j = i % w;
// shared[s] = Wq_packed[i];
// shared_meta[s][0] = zero[j];
// shared_meta[s][1] = scale[j];
// __syncthreads();


// W_r[i] = (scalar_t((shared[s] & 0xC0) >> 6) - shared_meta[s][0])*shared_meta[s][1]; //1st chunk
// W_r[i + n] = (scalar_t((shared[s] & 0x30) >> 4) - shared_meta[s][0])*shared_meta[s][1]; //2nd chunk
// W_r[i + n*2] = (scalar_t((shared[s] & 0x0C) >> 2) - shared_meta[s][0])*shared_meta[s][1]; //3rd chunk
// W_r[i + n*3] = (scalar_t((shared[s] & 0x03)) - shared_meta[s][0])*shared_meta[s][1]; //4th chunk
// }
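The host-side wrapper that launches this kernel appears just below, with its body collapsed in this view. A minimal launch sketch, assuming the usual one-thread-per-packed-byte pattern; AT_DISPATCHER is the hypothetical dispatcher name sketched near the top of the file, and the {4*h, w} output layout is an assumption:

// Sketch only: assumed wrapper body, not the actual (collapsed) implementation.
torch::Tensor dequantize_2bit_u8_sketch(torch::Tensor Wq_packed, torch::Tensor scale, torch::Tensor zero)
{
    CHECK_INPUT(Wq_packed); CHECK_INPUT(scale); CHECK_INPUT(zero);
    int h = Wq_packed.size(0);
    int w = Wq_packed.size(1);
    int n = h*w;
    auto W_r = torch::empty({h*4, w}, scale.options()); // assumed layout: 4 stacked chunks

    // One thread per packed byte; cdiv() and BLOCK_SIZE are defined at the top of this file.
    int blocks = cdiv(n, BLOCK_SIZE);

    AT_DISPATCHER(scale.scalar_type(), "dequantize_2bit_u8", ([&] {
        dequantize_2bit_u8_kernel<scalar_t><<<blocks, BLOCK_SIZE>>>(
            Wq_packed.data_ptr<unsigned char>(),
            scale.data_ptr<scalar_t>(),
            zero.data_ptr<scalar_t>(),
            W_r.data_ptr<scalar_t>(),
            h, w);
    }));

    return W_r;
}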



torch::Tensor dequantize_2bit_u8(torch::Tensor Wq_packed, torch::Tensor scale, torch::Tensor zero)
{
CHECK_INPUT(Wq_packed);
@@ -285,9 +313,36 @@ __global__ void dequantize_1bit_u8_kernel(unsigned char* Wq_packed, scalar_t* sc
W_r[i + n*5] = (scalar_t((Wq_packed[i] & 0x04) >> 2) - zero[j])*scale[j]; //6th chunk
W_r[i + n*6] = (scalar_t((Wq_packed[i] & 0x02) >> 1) - zero[j])*scale[j]; //7th chunk
W_r[i + n*7] = (scalar_t((Wq_packed[i] & 0x01)) - zero[j])*scale[j]; //8th chunk

}
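Similarly, the 1-bit path stores eight binary values per byte, most-significant bit first, so chunk k is recovered from bit (7 - k). A hypothetical packing helper mirroring those masks (illustration only, not part of this commit):

// Mirrors the 1-bit unpack masks above; hypothetical helper, not in this commit.
unsigned char pack_1bit_u8(const unsigned char q[8]) {
    unsigned char packed = 0;
    for (int k = 0; k < 8; ++k)
        packed |= (unsigned char)((q[k] & 0x01) << (7 - k)); // chunk k -> bit (7 - k)
    return packed;
}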

// //Shared-memory variant (placeholder)
// template <typename scalar_t>
// __global__ void dequantize_1bit_u8_kernel(unsigned char* Wq_packed, scalar_t* scale, scalar_t* zero, scalar_t* W_r, int h, int w) {
// int i = blockIdx.x*blockDim.x + threadIdx.x;
// int s = threadIdx.x;
// int n = h*w;
// if(i>=n) return;

// __shared__ unsigned char shared[BLOCK_SIZE];
// __shared__ scalar_t shared_meta[BLOCK_SIZE][2];

// int j = i % w;
// shared[s] = Wq_packed[i];
// shared_meta[s][0] = zero[j];
// shared_meta[s][1] = scale[j];
// __syncthreads();

// W_r[i] = (scalar_t((shared[s] & 0x80) >> 7) - shared_meta[s][0])*shared_meta[s][1]; //1st chunk
// W_r[i + n] = (scalar_t((shared[s] & 0x40) >> 6) - shared_meta[s][0])*shared_meta[s][1]; //2nd chunk
// W_r[i + n*2] = (scalar_t((shared[s] & 0x20) >> 5) - shared_meta[s][0])*shared_meta[s][1]; //3rd chunk
// W_r[i + n*3] = (scalar_t((shared[s] & 0x10) >> 4) - shared_meta[s][0])*shared_meta[s][1]; //4th chunk
// W_r[i + n*4] = (scalar_t((shared[s] & 0x08) >> 3) - shared_meta[s][0])*shared_meta[s][1]; //5th chunk
// W_r[i + n*5] = (scalar_t((shared[s] & 0x04) >> 2) - shared_meta[s][0])*shared_meta[s][1]; //6th chunk
// W_r[i + n*6] = (scalar_t((shared[s] & 0x02) >> 1) - shared_meta[s][0])*shared_meta[s][1]; //7th chunk
// W_r[i + n*7] = (scalar_t((shared[s] & 0x01)) - shared_meta[s][0])*shared_meta[s][1]; //8th chunk
// }


torch::Tensor dequantize_1bit_u8(torch::Tensor Wq_packed, torch::Tensor scale, torch::Tensor zero)
{
CHECK_INPUT(Wq_packed);
