[pull] master from ggerganov:master #5

Merged · 4 commits · Dec 19, 2023
22 changes: 22 additions & 0 deletions convert-hf-to-gguf.py
@@ -182,6 +182,8 @@ def from_model_architecture(model_architecture):
return QwenModel
if model_architecture == "MixtralForCausalLM":
return MixtralModel
if model_architecture == "PhiForCausalLM":
return Phi2Model
return Model

def _is_model_safetensors(self) -> bool:
@@ -221,6 +223,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
return gguf.MODEL_ARCH.QWEN
if arch == "MixtralForCausalLM":
return gguf.MODEL_ARCH.LLAMA
if arch == "PhiForCausalLM":
return gguf.MODEL_ARCH.PHI2

raise NotImplementedError(f'Architecture "{arch}" not supported!')

@@ -980,6 +984,24 @@ def write_tensors(self):
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)


class Phi2Model(Model):
def set_gguf_parameters(self):
block_count = self.hparams["n_layer"]

self.gguf_writer.add_name("Phi2")
self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(self.hparams["n_head"])
self.gguf_writer.add_head_count_kv(self.hparams["n_head"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"])
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_add_bos_token(False)


###### CONVERSION LOGIC ######


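For reference, the new Phi2Model maps the Hugging Face config fields onto GGUF metadata. A minimal standalone sketch of that mapping follows (a hypothetical helper, not part of the PR; the key names are simplified, but the config fields are the ones used in the diff above):

    import json

    def phi2_gguf_metadata(config_path: str) -> dict:
        # Reproduce the hparams -> GGUF metadata mapping performed by Phi2Model.
        with open(config_path) as f:
            hparams = json.load(f)
        return {
            "name":                 "Phi2",
            "context_length":       hparams["n_positions"],
            "embedding_length":     hparams["n_embd"],
            "feed_forward_length":  4 * hparams["n_embd"],   # FFN is 4x the hidden size
            "block_count":          hparams["n_layer"],
            "head_count":           hparams["n_head"],
            "head_count_kv":        hparams["n_head"],       # no grouped-query attention
            "layer_norm_eps":       hparams["layer_norm_epsilon"],
            "rope_dimension_count": hparams["rotary_dim"],   # partial rotary embedding
            "add_bos_token":        False,
        }

    # Usage (hypothetical local path):
    # print(phi2_gguf_metadata("phi-2/config.json"))
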
2 changes: 1 addition & 1 deletion examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -203,7 +203,7 @@ actor LlamaContext {
var pp_std: Double = 0
var tg_std: Double = 0

for r in 0..<nr {
for _ in 0..<nr {
// bench prompt processing

llama_batch_clear(&batch)
39 changes: 37 additions & 2 deletions examples/llama.swiftui/llama.swiftui/UI/ContentView.swift
@@ -75,21 +75,56 @@ struct ContentView: View {
VStack {
DownloadButton(
llamaState: llamaState,
modelName: "TinyLlama-1.1B (Q4_0)",
modelName: "TinyLlama-1.1B (Q4_0, 0.6 GiB)",
modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true",
filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf"
)
.font(.system(size: 12))
.padding(.top, 4)
.frame(maxWidth: .infinity, alignment: .leading)

DownloadButton(
llamaState: llamaState,
modelName: "TinyLlama-1.1B (Q8_0)",
modelName: "TinyLlama-1.1B (Q8_0, 1.1 GiB)",
modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q8_0.gguf?download=true",
filename: "tinyllama-1.1b-1t-openorca.Q8_0.gguf"
)
.font(.system(size: 12))

DownloadButton(
llamaState: llamaState,
modelName: "TinyLlama-1.1B (F16, 2.2 GiB)",
modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true",
filename: "tinyllama-1.1b-f16.gguf"
)
.font(.system(size: 12))
.frame(maxWidth: .infinity, alignment: .leading)

DownloadButton(
llamaState: llamaState,
modelName: "Phi-2.7B (Q4_0, 1.6 GiB)",
modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true",
filename: "phi-2-q4_0.gguf"
)
.font(.system(size: 12))

DownloadButton(
llamaState: llamaState,
modelName: "Phi-2.7B (Q8_0, 2.8 GiB)",
modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q8_0.gguf?download=true",
filename: "phi-2-q8_0.gguf"
)
.font(.system(size: 12))
.frame(maxWidth: .infinity, alignment: .leading)

DownloadButton(
llamaState: llamaState,
modelName: "Mistral-7B-v0.1 (Q4_0, 3.8 GiB)",
modelUrl: "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_0.gguf?download=true",
filename: "mistral-7b-v0.1.Q4_0.gguf"
)
.font(.system(size: 12))

Button("Clear downloaded models") {
ContentView.cleanupModelCaches()
llamaState.cacheCleared = true
119 changes: 83 additions & 36 deletions ggml-cuda.cu
@@ -31,6 +31,7 @@
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
@@ -40,6 +41,7 @@
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
@@ -4998,7 +5000,16 @@ static __global__ void rope_neox(
const int ib = col / n_dims;
const int ic = col % n_dims;

const int i = row*ncols + ib*n_dims + ic/2;
if (ib > 0) {
const int i = row*ncols + ib*n_dims + ic;

dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];

return;
}

const int i = row*ncols + ib*n_dims + ic/2;
const int i2 = row/p_delta_rows;

float cur_rot = inv_ndims * ic - ib;
@@ -7057,6 +7068,7 @@ inline void ggml_cuda_op_upscale(

(void) src1;
(void) dst;
(void) src1_dd;
}

inline void ggml_cuda_op_pad(
@@ -7073,6 +7085,7 @@ inline void ggml_cuda_op_pad(

(void) src1;
(void) dst;
(void) src1_dd;
}

inline void ggml_cuda_op_rms_norm(
@@ -7376,7 +7389,7 @@ inline void ggml_cuda_op_mul_mat_cublas(

const int compute_capability = g_compute_capabilities[id];

if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
half * src0_as_f16 = nullptr;
size_t src0_as = 0;
@@ -8300,27 +8313,27 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
}

static __global__ void k_compute_batched_ptrs(
const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
const half * src0_as_f16, const half * src1_as_f16, char * dst,
const void ** ptrs_src, void ** ptrs_dst,
int ne12, int ne13,
int ne23,
int nb02, int nb03,
int nb12, int nb13,
int nb2, int nb3,
int r2, int r3) {
int i13 = blockIdx.x * blockDim.x + threadIdx.x;
int i12 = blockIdx.y * blockDim.y + threadIdx.y;
int64_t ne12, int64_t ne13,
int64_t ne23,
size_t nb02, size_t nb03,
size_t nb12, size_t nb13,
size_t nbd2, size_t nbd3,
int64_t r2, int64_t r3) {
int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;

if (i13 >= ne13 || i12 >= ne12) {
return;
}

int i03 = i13 / r3;
int i02 = i12 / r2;
int64_t i03 = i13 / r3;
int64_t i02 = i12 / r2;

ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
}

static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -8376,7 +8389,41 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);

size_t dst_as = 0;
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);

half * dst_f16 = nullptr;
char * dst_t = nullptr;

cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
cudaDataType_t cu_data_type = CUDA_R_16F;

// dst strides
size_t nbd2 = dst->nb[2];
size_t nbd3 = dst->nb[3];

const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;

const float alpha_f32 = 1.0f;
const float beta_f32 = 0.0f;

const void * alpha = &alpha_f16;
const void * beta = &beta_f16;

if (dst->op_params[0] == GGML_PREC_DEFAULT) {
dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
dst_t = (char *) dst_f16;

nbd2 /= sizeof(float) / sizeof(half);
nbd3 /= sizeof(float) / sizeof(half);
} else {
dst_t = (char *) dst_ddf;

cu_compute_type = CUBLAS_COMPUTE_32F;
cu_data_type = CUDA_R_32F;

alpha = &alpha_f32;
beta = &beta_f32;
}

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
@@ -8385,9 +8432,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;

const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;

#if 0
// use cublasGemmEx
{
@@ -8397,12 +8441,12 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
int i02 = i12 / r2;

CUBLAS_CHECK(
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
&beta_f16, ( char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
CUBLAS_COMPUTE_16F,
alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
}
@@ -8414,11 +8458,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
CUBLAS_CHECK(
cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
&beta_f16, ( char *) dst_f16, CUDA_R_16F, ne01, dst->nb[2]/sizeof(float), // strideC
alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
beta, ( char *) dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC
ne12*ne13,
CUBLAS_COMPUTE_16F,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
// use cublasGemmBatchedEx
@@ -8435,24 +8479,24 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const

dim3 block_dims(ne13, ne12);
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
src0_as_f16, src1_as_f16, dst_f16,
src0_as_f16, src1_as_f16, dst_t,
ptrs_src, ptrs_dst,
ne12, ne13,
ne23,
nb02, nb03,
nb12, nb13,
dst->nb[2], dst->nb[3],
nbd2, nbd3,
r2, r3);
CUDA_CHECK(cudaGetLastError());

CUBLAS_CHECK(
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
&beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
beta, ( void **) (ptrs_dst + 0*ne23), cu_data_type, ne01,
ne23,
CUBLAS_COMPUTE_16F,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

if (ptrs_src_s != 0) {
@@ -8464,11 +8508,14 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
}
#endif

const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);

ggml_cuda_pool_free(dst_f16, dst_as);
}

ggml_cuda_pool_free(src1_as_f16, src1_as);
ggml_cuda_pool_free(dst_f16, dst_as);
}

static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
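The batched cuBLAS changes above make the destination precision follow dst->op_params[0]: with GGML_PREC_DEFAULT the result is still produced into a temporary F16 buffer and converted to F32 afterwards, otherwise it is written directly as F32 using CUBLAS_COMPUTE_32F and no conversion pass. A rough NumPy sketch of just the dtype flow, not the real CUDA API (the function name and argument shapes are illustrative):

    import numpy as np

    def batched_matmul_dst_precision(a_f16: np.ndarray, b_f16: np.ndarray, prec_default: bool) -> np.ndarray:
        # prec_default=True : produce an f16 destination buffer, then run a separate
        #                     f16 -> f32 conversion pass (previous behaviour, kept
        #                     for GGML_PREC_DEFAULT).
        # prec_default=False: produce the destination directly in f32, skipping the
        #                     temporary buffer and the conversion pass.
        if prec_default:
            dst_f16 = (a_f16 @ b_f16).astype(np.float16)
            return dst_f16.astype(np.float32)
        return a_f16.astype(np.float32) @ b_f16.astype(np.float32)
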
13 changes: 11 additions & 2 deletions ggml-metal.metal
@@ -1702,8 +1702,9 @@ kernel void kernel_rope(
dst_data[1] = x0*sin_theta + x1*cos_theta;
}
} else {
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
if (ic < n_dims) {
const int64_t ib = 0;

// simplified from `(ib * n_dims + ic) * inv_ndims`
const float cur_rot = inv_ndims*ic - ib;
@@ -1722,6 +1723,14 @@

dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
const int64_t i0 = ic;

device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
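Both the CUDA rope_neox change and the Metal kernel_rope change above implement the same partial-rotation behaviour needed when rotary_dim is smaller than the head size: only the first n_dims columns are rotated, and the remaining columns are copied through unchanged. A minimal NumPy sketch of that behaviour for a single row (illustrative only; fixed frequency base, no frequency scaling):

    import numpy as np

    def rope_neox_partial(x: np.ndarray, pos: int, n_dims: int, freq_base: float = 10000.0) -> np.ndarray:
        # Rotate NeoX-style pairs (i, i + n_dims/2) within the first n_dims entries;
        # entries at index >= n_dims are passed through untouched.
        out = x.astype(np.float32)
        half = n_dims // 2
        for i in range(half):
            theta = pos * freq_base ** (-2.0 * i / n_dims)
            cos_t, sin_t = np.cos(theta), np.sin(theta)
            x0, x1 = out[i], out[i + half]
            out[i]        = x0 * cos_t - x1 * sin_t
            out[i + half] = x0 * sin_t + x1 * cos_t
        return out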