Skip to content

Commit

Permalink
Merge branch 'master' into phi-1
Browse files Browse the repository at this point in the history
  • Loading branch information
teleprint-me committed Dec 25, 2023
2 parents 348d565 + b9f4795 commit b0583f7
Show file tree
Hide file tree
Showing 13 changed files with 636 additions and 399 deletions.
177 changes: 1 addition & 176 deletions .github/ISSUE_TEMPLATE/bug.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,179 +6,4 @@ assignees: ''

---

# Prerequisites

Please answer the following questions for yourself before submitting an issue.

- [ ] I am running the latest code. Development is very rapid so there are no tagged versions as of now.
- [ ] I carefully followed the [README.md](https://github.com/ggerganov/llama.cpp/blob/master/README.md).
- [ ] I [searched using keywords relevant to my issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/filtering-and-searching-issues-and-pull-requests) to make sure that I am creating a new issue that is not already open (or closed).
- [ ] I reviewed the [Discussions](https://github.com/ggerganov/llama.cpp/discussions), and have a new bug or useful enhancement to share.

# Expected Behavior

Please provide a detailed written description of what you were trying to do, and what you expected `llama.cpp` to do.

# Current Behavior

Please provide a detailed written description of what `llama.cpp` did, instead.

# Environment and Context

Please provide detailed information about your computer setup. This is important in case the issue is not reproducible except for under certain specific conditions.

* Physical (or virtual) hardware you are using, e.g. for Linux:

`$ lscpu`

* Operating System, e.g. for Linux:

`$ uname -a`

* SDK version, e.g. for Linux:

```
$ python3 --version
$ make --version
$ g++ --version
```

# Failure Information (for bugs)

Please help provide information about the failure / bug.

# Steps to Reproduce

Please provide detailed steps for reproducing the issue. We are not sitting in front of your screen, so the more detail the better.

1. step 1
2. step 2
3. step 3
4. etc.

# Failure Logs

Please include any relevant log snippets or files. If it works under one configuration but not under another, please provide logs for both configurations and their corresponding outputs so it is easy to see where behavior changes.

Also, please try to **avoid using screenshots** if at all possible. Instead, copy/paste the console output and use [Github's markdown](https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax) to cleanly format your logs for easy readability.

Example environment info:
```
llama.cpp$ git log | head -1
commit 2af23d30434a677c6416812eea52ccc0af65119c
llama.cpp$ lscpu | egrep "AMD|Flags"
Vendor ID: AuthenticAMD
Model name: AMD Ryzen Threadripper 1950X 16-Core Processor
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid amd_dcm aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb hw_pstate ssbd ibpb vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt sha_ni xsaveopt xsavec xgetbv1 xsaves clzero irperf xsaveerptr arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif overflow_recov succor smca sme sev
Virtualization: AMD-V
llama.cpp$ python3 --version
Python 3.10.9
llama.cpp$ pip list | egrep "torch|numpy|sentencepiece"
numpy 1.24.2
numpydoc 1.5.0
sentencepiece 0.1.97
torch 1.13.1
torchvision 0.14.1
llama.cpp$ make --version | head -1
GNU Make 4.3
$ md5sum ./models/65B/ggml-model-q4_0.bin
dbdd682cce80e2d6e93cefc7449df487 ./models/65B/ggml-model-q4_0.bin
```

Example run with the Linux command [perf](https://www.brendangregg.com/perf.html)
```
llama.cpp$ perf stat ./main -m ./models/65B/ggml-model-q4_0.bin -t 16 -n 1024 -p "Please close your issue when it has been answered."
main: seed = 1679149377
llama_model_load: loading model from './models/65B/ggml-model-q4_0.bin' - please wait ...
llama_model_load: n_vocab = 32000
llama_model_load: n_ctx = 512
llama_model_load: n_embd = 8192
llama_model_load: n_mult = 256
llama_model_load: n_head = 64
llama_model_load: n_layer = 80
llama_model_load: n_rot = 128
llama_model_load: f16 = 2
llama_model_load: n_ff = 22016
llama_model_load: n_parts = 8
llama_model_load: ggml ctx size = 41477.73 MB
llama_model_load: memory_size = 2560.00 MB, n_mem = 40960
llama_model_load: loading model part 1/8 from './models/65B/ggml-model-q4_0.bin'
llama_model_load: .......................................................................................... done
llama_model_load: model size = 4869.09 MB / num tensors = 723
llama_model_load: loading model part 2/8 from './models/65B/ggml-model-q4_0.bin.1'
llama_model_load: .......................................................................................... done
llama_model_load: model size = 4869.09 MB / num tensors = 723
llama_model_load: loading model part 3/8 from './models/65B/ggml-model-q4_0.bin.2'
llama_model_load: .......................................................................................... done
llama_model_load: model size = 4869.09 MB / num tensors = 723
llama_model_load: loading model part 4/8 from './models/65B/ggml-model-q4_0.bin.3'
llama_model_load: .......................................................................................... done
llama_model_load: model size = 4869.09 MB / num tensors = 723
llama_model_load: loading model part 5/8 from './models/65B/ggml-model-q4_0.bin.4'
llama_model_load: .......................................................................................... done
llama_model_load: model size = 4869.09 MB / num tensors = 723
llama_model_load: loading model part 6/8 from './models/65B/ggml-model-q4_0.bin.5'
llama_model_load: .......................................................................................... done
llama_model_load: model size = 4869.09 MB / num tensors = 723
llama_model_load: loading model part 7/8 from './models/65B/ggml-model-q4_0.bin.6'
llama_model_load: .......................................................................................... done
llama_model_load: model size = 4869.09 MB / num tensors = 723
llama_model_load: loading model part 8/8 from './models/65B/ggml-model-q4_0.bin.7'
llama_model_load: .......................................................................................... done
llama_model_load: model size = 4869.09 MB / num tensors = 723
system_info: n_threads = 16 / 32 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | VSX = 0 |
main: prompt: 'Please close your issue when it has been answered.'
main: number of tokens in prompt = 11
1 -> ''
12148 -> 'Please'
3802 -> ' close'
596 -> ' your'
2228 -> ' issue'
746 -> ' when'
372 -> ' it'
756 -> ' has'
1063 -> ' been'
7699 -> ' answered'
29889 -> '.'
sampling parameters: temp = 0.800000, top_k = 40, top_p = 0.950000, repeat_last_n = 64, repeat_penalty = 1.300000
Please close your issue when it has been answered.
@duncan-donut: I'm trying to figure out what kind of "support" you need for this script and why, exactly? Is there a question about how the code works that hasn't already been addressed in one or more comments below this ticket, or are we talking something else entirely like some sorta bugfixing job because your server setup is different from mine??
I can understand if your site needs to be running smoothly and you need help with a fix of sorts but there should really be nothing wrong here that the code itself could not handle. And given that I'm getting reports about how it works perfectly well on some other servers, what exactly are we talking? A detailed report will do wonders in helping us get this resolved for ya quickly so please take your time and describe the issue(s) you see as clearly & concisely as possible!!
@duncan-donut: I'm not sure if you have access to cPanel but you could try these instructions. It is worth a shot! Let me know how it goes (or what error message, exactly!) when/if ya give that code a go? [end of text]
main: mem per token = 71159620 bytes
main: load time = 19309.95 ms
main: sample time = 168.62 ms
main: predict time = 223895.61 ms / 888.47 ms per token
main: total time = 246406.42 ms
Performance counter stats for './main -m ./models/65B/ggml-model-q4_0.bin -t 16 -n 1024 -p Please close your issue when it has been answered.':
3636882.89 msec task-clock # 14.677 CPUs utilized
13509 context-switches # 3.714 /sec
2436 cpu-migrations # 0.670 /sec
10476679 page-faults # 2.881 K/sec
13133115082869 cycles # 3.611 GHz (16.77%)
29314462753 stalled-cycles-frontend # 0.22% frontend cycles idle (16.76%)
10294402631459 stalled-cycles-backend # 78.39% backend cycles idle (16.74%)
23479217109614 instructions # 1.79 insn per cycle
# 0.44 stalled cycles per insn (16.76%)
2353072268027 branches # 647.002 M/sec (16.77%)
1998682780 branch-misses # 0.08% of all branches (16.76%)
247.802177522 seconds time elapsed
3618.573072000 seconds user
18.491698000 seconds sys
```
Please include information about your system, the steps to reproduce the bug, and the version of llama.cpp that you are using. If possible, please provide a minimal code example that reproduces the bug.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ if (LLAMA_CUBLAS)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
endif()

set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver)

if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
# 52 == lowest CUDA 12 standard
# 60 == f16 CUDA intrinsics
Expand Down
6 changes: 2 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -367,17 +367,15 @@ endif # LLAMA_BLIS

ifdef LLAMA_CUBLAS
MK_CPPFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include -I/usr/local/cuda/targets/aarch64-linux/include
MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib
MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib -L/usr/lib/wsl/lib
OBJS += ggml-cuda.o
MK_NVCCFLAGS = -use_fast_math
ifndef JETSON_EOL_MODULE_DETECT
MK_NVCCFLAGS += --forward-unknown-to-host-compiler
endif # JETSON_EOL_MODULE_DETECT

ifdef LLAMA_DEBUG
MK_NVCCFLAGS += -lineinfo
endif

endif # LLAMA_DEBUG
ifdef LLAMA_CUDA_NVCC
NVCC = $(LLAMA_CUDA_NVCC)
else
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ as the main playground for developing new features for the [ggml](https://github
- [x] [Deepseek models](https://huggingface.co/models?search=deepseek-ai/deepseek)
- [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen)
- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
- [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557)

**Multimodal models:**

Expand Down
86 changes: 85 additions & 1 deletion convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ def from_model_architecture(model_architecture):
return MixtralModel
if model_architecture == "PhiForCausalLM":
return PhiModel
if model_architecture == "PlamoForCausalLM":
return PlamoModel
return Model

def _is_model_safetensors(self) -> bool:
Expand Down Expand Up @@ -224,6 +226,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
return gguf.MODEL_ARCH.LLAMA
if arch == "PhiForCausalLM":
return gguf.MODEL_ARCH.PHI
if arch == "PlamoForCausalLM":
return gguf.MODEL_ARCH.PLAMO

raise NotImplementedError(f'Architecture "{arch}" not supported!')

Expand Down Expand Up @@ -1001,11 +1005,91 @@ def set_gguf_parameters(self):
self.gguf_writer.add_add_bos_token(False)


class PlamoModel(Model):
def set_vocab(self):
self._set_vocab_sentencepiece()

def set_gguf_parameters(self):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]

self.gguf_writer.add_name("PLaMo")
self.gguf_writer.add_context_length(4096) # not in config.json
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])

def shuffle_attn_q_weight(self, data_torch):
assert data_torch.size() == (5120, 5120)
data_torch = data_torch.reshape(8, 5, 128, 5120)
data_torch = torch.permute(data_torch, (1, 0, 2, 3))
data_torch = torch.reshape(data_torch, (5120, 5120))
return data_torch

def shuffle_attn_output_weight(self, data_torch):
assert data_torch.size() == (5120, 5120)
data_torch = data_torch.reshape(5120, 8, 5, 128)
data_torch = torch.permute(data_torch, (0, 2, 1, 3))
data_torch = torch.reshape(data_torch, (5120, 5120))
return data_torch

def write_tensors(self):
block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

for name, data_torch in self.get_tensors():
if "self_attn.rotary_emb.inv_freq" in name:
continue

# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()

# shuffle for broadcasting of gqa in ggml_mul_mat
if new_name.endswith("attn_q.weight"):
data_torch = self.shuffle_attn_q_weight(data_torch)
elif new_name.endswith("attn_output.weight"):
data_torch = self.shuffle_attn_output_weight(data_torch)

old_dtype = data_torch.dtype

# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)

data = data_torch.squeeze().numpy()

n_dims = len(data.shape)
data_dtype = data.dtype

# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)

# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)

# if f16 desired, convert any float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

self.gguf_writer.add_tensor(new_name, data)


###### CONVERSION LOGIC ######


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert a huggingface model to a GGML compatible file")
parser = argparse.ArgumentParser(
description="Convert a huggingface model to a GGML compatible file")
parser.add_argument(
"--vocab-only", action="store_true",
help="extract only the vocab",
Expand Down
16 changes: 6 additions & 10 deletions ggml-backend.c
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ static void ggml_backend_registry_init(void) {
void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
GGML_ASSERT(ggml_backend_registry_count < GGML_MAX_BACKENDS_REG);

int id = ggml_backend_registry_count;
size_t id = ggml_backend_registry_count;

ggml_backend_registry[id] = (struct ggml_backend_reg) {
/* .name = */ {0},
Expand Down Expand Up @@ -330,6 +330,8 @@ size_t ggml_backend_reg_find_by_name(const char * name) {
return i;
}
}

// not found
return SIZE_MAX;
}

Expand All @@ -340,15 +342,15 @@ ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str)
const char * params = strchr(backend_str, ':');
char backend_name[128];
if (params == NULL) {
strcpy(backend_name, backend_str);
snprintf(backend_name, sizeof(backend_name), "%s", backend_str);
params = "";
} else {
strncpy(backend_name, backend_str, params - backend_str);
backend_name[params - backend_str] = '\0';
snprintf(backend_name, sizeof(backend_name), "%.*s", (int)(params - backend_str), backend_str);
params++;
}

size_t backend_i = ggml_backend_reg_find_by_name(backend_name);

if (backend_i == SIZE_MAX) {
fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name);
return NULL;
Expand Down Expand Up @@ -396,18 +398,12 @@ static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
}

static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");

memcpy((char *)tensor->data + offset, data, size);

GGML_UNUSED(buffer);
}

static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");

memcpy(data, (const char *)tensor->data + offset, size);

GGML_UNUSED(buffer);
Expand Down
Loading

0 comments on commit b0583f7

Please sign in to comment.