diff --git a/hqq/kernels/hqq_aten_cuda_kernel.cu b/hqq/kernels/hqq_aten_cuda_kernel.cu
index c538d6d..4abe4e8 100755
--- a/hqq/kernels/hqq_aten_cuda_kernel.cu
+++ b/hqq/kernels/hqq_aten_cuda_kernel.cu
@@ -8,8 +8,8 @@
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}
-#define BLOCK_SIZE 256
-#define SHARED_SIZE 512 //2xBLOCK_SIZE
+#define BLOCK_SIZE 256 //~256
+#define SHARED_SIZE 512 //~512
 
 //Custom Dispatcher to support Float, Half, Bfloat16 since the in Aten doens't support bfp16: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Dispatch.h#L248
 #define AT_DISPATCHER_CASE(...) \
@@ -198,6 +198,34 @@ __global__ void dequantize_2bit_u8_kernel(unsigned char* Wq_packed, scalar_t* sc
 	W_r[i + n*3] = (scalar_t((Wq_packed[i] & 0x03)) - zero[j])*scale[j]; //4th chunk
 }
 
+
+// //Shared
+// template <typename scalar_t>
+// __global__ void dequantize_2bit_u8_kernel(unsigned char* Wq_packed, scalar_t* scale, scalar_t* zero, scalar_t* W_r, int h, int w) {
+// 	int i = blockIdx.x*blockDim.x + threadIdx.x;
+// 	int n = h*w;
+// 	int s = threadIdx.x;
+
+// 	if(i>=n) return;
+
+// 	__shared__ unsigned char shared[BLOCK_SIZE];
+// 	__shared__ scalar_t shared_meta[BLOCK_SIZE][2];
+
+// 	int j = i % w;
+// 	shared[s] = Wq_packed[i];
+// 	shared_meta[s][0] = zero[j];
+// 	shared_meta[s][1] = scale[j];
+// 	__syncthreads();
+
+
+// 	W_r[i]       = (scalar_t((shared[s] & 0xC0) >> 6) - shared_meta[s][0])*shared_meta[s][1]; //1st chunk
+// 	W_r[i + n]   = (scalar_t((shared[s] & 0x30) >> 4) - shared_meta[s][0])*shared_meta[s][1]; //2nd chunk
+// 	W_r[i + n*2] = (scalar_t((shared[s] & 0x0C) >> 2) - shared_meta[s][0])*shared_meta[s][1]; //3rd chunk
+// 	W_r[i + n*3] = (scalar_t((shared[s] & 0x03))      - shared_meta[s][0])*shared_meta[s][1]; //4th chunk
+// }
+
+
+
 torch::Tensor dequantize_2bit_u8(torch::Tensor Wq_packed, torch::Tensor scale, torch::Tensor zero)
 {
 	CHECK_INPUT(Wq_packed);
@@ -285,9 +313,36 @@ __global__ void dequantize_1bit_u8_kernel(unsigned char* Wq_packed, scalar_t* sc
 	W_r[i + n*5] = (scalar_t((Wq_packed[i] & 0x04) >> 2) - zero[j])*scale[j]; //6th chunk
 	W_r[i + n*6] = (scalar_t((Wq_packed[i] & 0x02) >> 1) - zero[j])*scale[j]; //7th chunk
 	W_r[i + n*7] = (scalar_t((Wq_packed[i] & 0x01)) - zero[j])*scale[j]; //8th chunk
-
 }
 
+// //Shared
+// template <typename scalar_t>
+// __global__ void dequantize_1bit_u8_kernel(unsigned char* Wq_packed, scalar_t* scale, scalar_t* zero, scalar_t* W_r, int h, int w) {
+// 	int i = blockIdx.x*blockDim.x + threadIdx.x;
+// 	int s = threadIdx.x;
+// 	int n = h*w;
+// 	if(i>=n) return;
+
+// 	__shared__ unsigned char shared[BLOCK_SIZE];
+// 	__shared__ scalar_t shared_meta[BLOCK_SIZE][2];
+
+// 	int j = i % w;
+// 	shared[s] = Wq_packed[i];
+// 	shared_meta[s][0] = zero[j];
+// 	shared_meta[s][1] = scale[j];
+// 	__syncthreads();
+
+// 	W_r[i]       = (scalar_t((shared[s] & 0x80) >> 7) - shared_meta[s][0])*shared_meta[s][1]; //1st chunk
+// 	W_r[i + n]   = (scalar_t((shared[s] & 0x40) >> 6) - shared_meta[s][0])*shared_meta[s][1]; //2nd chunk
+// 	W_r[i + n*2] = (scalar_t((shared[s] & 0x20) >> 5) - shared_meta[s][0])*shared_meta[s][1]; //3rd chunk
+// 	W_r[i + n*3] = (scalar_t((shared[s] & 0x10) >> 4) - shared_meta[s][0])*shared_meta[s][1]; //4th chunk
+// 	W_r[i + n*4] = (scalar_t((shared[s] & 0x08) >> 3) - shared_meta[s][0])*shared_meta[s][1]; //5th chunk
+// 	W_r[i + n*5] = (scalar_t((shared[s] & 0x04) >> 2) - shared_meta[s][0])*shared_meta[s][1]; //6th chunk
+// 	W_r[i + n*6] = (scalar_t((shared[s] & 0x02) >> 1) - shared_meta[s][0])*shared_meta[s][1]; //7th chunk
+// 	W_r[i + n*7] = (scalar_t((shared[s] & 0x01))      - shared_meta[s][0])*shared_meta[s][1]; //8th chunk
+// }
+
+
 torch::Tensor dequantize_1bit_u8(torch::Tensor Wq_packed, torch::Tensor scale, torch::Tensor zero)
 {
 	CHECK_INPUT(Wq_packed);
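Reviewer note: for reference, the sketch below is a minimal, self-contained illustration (plain CUDA, float-only, no PyTorch/ATen dispatch) of the 2-bit packing scheme the kernels in this patch assume: each packed byte holds four 2-bit values that are written to four row-chunks of the output, with scale/zero indexed per column (j = i % w). The kernel name dequantize_2bit_u8_sketch, the test sizes, and the host driver are illustrative assumptions, not part of the patch or the hqq API.

#include <cstdio>
#include <cuda_runtime.h>

#define BLOCK_SIZE 256

inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b; }

// Mirrors the per-thread logic of dequantize_2bit_u8_kernel in the patch, restricted to float.
__global__ void dequantize_2bit_u8_sketch(const unsigned char* Wq_packed, const float* scale,
                                          const float* zero, float* W_r, int h, int w) {
	int i = blockIdx.x*blockDim.x + threadIdx.x;
	int n = h*w;               // number of packed bytes
	if (i >= n) return;
	int j = i % w;             // column index -> per-column scale/zero
	unsigned char b = Wq_packed[i];
	W_r[i]       = (float((b & 0xC0) >> 6) - zero[j])*scale[j]; // 1st chunk (bits 7-6)
	W_r[i + n]   = (float((b & 0x30) >> 4) - zero[j])*scale[j]; // 2nd chunk (bits 5-4)
	W_r[i + n*2] = (float((b & 0x0C) >> 2) - zero[j])*scale[j]; // 3rd chunk (bits 3-2)
	W_r[i + n*3] = (float((b & 0x03))      - zero[j])*scale[j]; // 4th chunk (bits 1-0)
}

int main() {
	const int h = 4, w = 8, n = h*w;   // n packed bytes -> 4*n dequantized values
	unsigned char hq[n]; float hs[w], hz[w], hout[4*n];
	for (int j = 0; j < w; ++j) { hs[j] = 0.5f; hz[j] = 1.0f; }
	// Pack the values 3,2,1,0 into each byte, one per chunk.
	for (int i = 0; i < n; ++i) hq[i] = (unsigned char)((3 << 6) | (2 << 4) | (1 << 2) | 0);

	unsigned char* dq; float *ds, *dz, *dout;
	cudaMalloc(&dq, n);                     cudaMalloc(&ds, w*sizeof(float));
	cudaMalloc(&dz, w*sizeof(float));       cudaMalloc(&dout, 4*n*sizeof(float));
	cudaMemcpy(dq, hq, n, cudaMemcpyHostToDevice);
	cudaMemcpy(ds, hs, w*sizeof(float), cudaMemcpyHostToDevice);
	cudaMemcpy(dz, hz, w*sizeof(float), cudaMemcpyHostToDevice);

	dequantize_2bit_u8_sketch<<<cdiv(n, BLOCK_SIZE), BLOCK_SIZE>>>(dq, ds, dz, dout, h, w);
	cudaMemcpy(hout, dout, 4*n*sizeof(float), cudaMemcpyDeviceToHost);

	// Expect (q - zero)*scale per chunk: (3-1)*0.5=1.0, (2-1)*0.5=0.5, (1-1)*0.5=0.0, (0-1)*0.5=-0.5
	printf("%.1f %.1f %.1f %.1f\n", hout[0], hout[n], hout[2*n], hout[3*n]);
	cudaFree(dq); cudaFree(ds); cudaFree(dz); cudaFree(dout);
	return 0;
}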