Merge branch 'ggerganov:master' into gguf-model-template
teleprint-me authored Jul 28, 2024
2 parents 43eef8d + 6eeaeba commit 964ee4b
Showing 62 changed files with 2,038 additions and 1,603 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/build.yml
@@ -860,7 +860,8 @@ jobs:
mkdir build
cd build
cmake .. -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_CUDA=ON -DBUILD_SHARED_LIBS=ON
cmake --build . --config Release -j $((${env:NUMBER_OF_PROCESSORS} - 1))
cmake --build . --config Release -j $((${env:NUMBER_OF_PROCESSORS} - 1)) -t ggml
cmake --build . --config Release -j ${env:NUMBER_OF_PROCESSORS}
- name: Determine tag name
id: tag
1 change: 1 addition & 0 deletions .gitignore
@@ -50,6 +50,7 @@ build*
!docs/build.md
/libllama.so
/llama-*
/vulkan-shaders-gen
android-ndk-*
arm_neon.h
cmake-build-*
61 changes: 49 additions & 12 deletions Makefile
@@ -325,9 +325,9 @@ ifdef LLAMA_DEBUG
endif
else
MK_CPPFLAGS += -DNDEBUG
MK_CFLAGS += -O3
MK_CXXFLAGS += -O3
MK_NVCCFLAGS += -O3
MK_CFLAGS += -O3 -g
MK_CXXFLAGS += -O3 -g
MK_NVCCFLAGS += -O3 -g
endif

ifdef LLAMA_SANITIZE_THREAD
@@ -528,10 +528,21 @@ ifndef GGML_NO_ACCELERATE
endif
endif # GGML_NO_ACCELERATE

ifdef GGML_MUSA
CC := clang
CXX := clang++
GGML_CUDA := 1
MK_CPPFLAGS += -DGGML_USE_MUSA
endif

ifndef GGML_NO_OPENMP
MK_CPPFLAGS += -DGGML_USE_OPENMP
MK_CFLAGS += -fopenmp
MK_CXXFLAGS += -fopenmp
ifdef GGML_MUSA
MK_CPPFLAGS += -I/usr/lib/llvm-10/include/openmp
MK_LDFLAGS += -L/usr/lib/llvm-10/lib
endif # GGML_MUSA
endif # GGML_NO_OPENMP

ifdef GGML_OPENBLAS
@@ -582,15 +593,27 @@ else
endif # GGML_CUDA_FA_ALL_QUANTS

ifdef GGML_CUDA
ifneq ('', '$(wildcard /opt/cuda)')
CUDA_PATH ?= /opt/cuda
ifdef GGML_MUSA
ifneq ('', '$(wildcard /opt/musa)')
CUDA_PATH ?= /opt/musa
else
CUDA_PATH ?= /usr/local/musa
endif

MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include
MK_LDFLAGS += -lmusa -lmublas -lmusart -lpthread -ldl -lrt -L$(CUDA_PATH)/lib -L/usr/lib64
MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_22
else
CUDA_PATH ?= /usr/local/cuda
endif
ifneq ('', '$(wildcard /opt/cuda)')
CUDA_PATH ?= /opt/cuda
else
CUDA_PATH ?= /usr/local/cuda
endif

MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS
MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib
MK_NVCCFLAGS += -use_fast_math
MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS
MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib
MK_NVCCFLAGS += -use_fast_math
endif # GGML_MUSA

OBJ_GGML += ggml/src/ggml-cuda.o
OBJ_GGML += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu))
@@ -600,9 +623,11 @@ ifdef LLAMA_FATAL_WARNINGS
MK_NVCCFLAGS += -Werror all-warnings
endif # LLAMA_FATAL_WARNINGS

ifndef GGML_MUSA
ifndef JETSON_EOL_MODULE_DETECT
MK_NVCCFLAGS += --forward-unknown-to-host-compiler
endif # JETSON_EOL_MODULE_DETECT
endif # GGML_MUSA

ifdef LLAMA_DEBUG
MK_NVCCFLAGS += -lineinfo
@@ -615,8 +640,12 @@ endif # GGML_CUDA_DEBUG
ifdef GGML_CUDA_NVCC
NVCC = $(CCACHE) $(GGML_CUDA_NVCC)
else
NVCC = $(CCACHE) nvcc
endif #GGML_CUDA_NVCC
ifdef GGML_MUSA
NVCC = $(CCACHE) mcc
else
NVCC = $(CCACHE) nvcc
endif # GGML_MUSA
endif # GGML_CUDA_NVCC

ifdef CUDA_DOCKER_ARCH
MK_NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
@@ -687,9 +716,15 @@ define NVCC_COMPILE
$(NVCC) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
endef # NVCC_COMPILE
else
ifdef GGML_MUSA
define NVCC_COMPILE
$(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -c $< -o $@
endef # NVCC_COMPILE
else
define NVCC_COMPILE
$(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
endef # NVCC_COMPILE
endif # GGML_MUSA
endif # JETSON_EOL_MODULE_DETECT

ggml/src/ggml-cuda/%.o: \
@@ -944,6 +979,7 @@ $(info I CXX: $(shell $(CXX) --version | head -n 1))
ifdef GGML_CUDA
$(info I NVCC: $(shell $(NVCC) --version | tail -n 1))
CUDA_VERSION := $(shell $(NVCC) --version | grep -oP 'release (\K[0-9]+\.[0-9])')
ifndef GGML_MUSA
ifeq ($(shell awk -v "v=$(CUDA_VERSION)" 'BEGIN { print (v < 11.7) }'),1)

ifndef CUDA_DOCKER_ARCH
@@ -953,6 +989,7 @@ endif # CUDA_POWER_ARCH
endif # CUDA_DOCKER_ARCH

endif # eq ($(shell echo "$(CUDA_VERSION) < 11.7" | bc),1)
endif # GGML_MUSA
endif # GGML_CUDA
$(info )

1 change: 1 addition & 0 deletions README.md
@@ -409,6 +409,7 @@ Please refer to [Build llama.cpp locally](./docs/build.md)
| [BLAS](./docs/build.md#blas-build) | All |
| [BLIS](./docs/backend/BLIS.md) | All |
| [SYCL](./docs/backend/SYCL.md) | Intel and Nvidia GPU |
| [MUSA](./docs/build.md#musa) | Moore Threads GPU |
| [CUDA](./docs/build.md#cuda) | Nvidia GPU |
| [hipBLAS](./docs/build.md#hipblas) | AMD GPU |
| [Vulkan](./docs/build.md#vulkan) | GPU |
5 changes: 5 additions & 0 deletions common/common.cpp
@@ -1324,6 +1324,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
else { invalid_param = true; }
return true;
}
if (arg == "--no-warmup") {
params.warmup = false;
return true;
}
#ifndef LOG_DISABLE_LOGS
// Parse args for logging parameters
if (log_param_single_parse(argv[i])) {
@@ -1446,6 +1450,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "main infill", " --in-prefix-bos", "prefix BOS to user inputs, preceding the `--in-prefix` string" });
options.push_back({ "main infill", " --in-prefix STRING", "string to prefix user inputs with (default: empty)" });
options.push_back({ "main infill", " --in-suffix STRING", "string to suffix after user inputs with (default: empty)" });
options.push_back({ "main", " --no-warmup", "skip warming up the model with an empty run" });
options.push_back({ "server infill",
" --spm-infill", "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", params.spm_infill ? "enabled" : "disabled" });

28 changes: 28 additions & 0 deletions convert_hf_to_gguf.py
@@ -1570,6 +1570,34 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return [(self.map_tensor_name(name), data_torch)]

def prepare_tensors(self):
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
if rope_scaling.get("rope_type", '').lower() == "llama3":
base = self.hparams.get("rope_theta", 10000.0)
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

factor = rope_scaling.get("factor", 8.0)
low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
assert low_freq_wavelen != high_freq_wavelen

rope_factors = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
rope_factors.append(1)
elif wavelen > low_freq_wavelen:
rope_factors.append(factor)
else:
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
rope_factors.append(1 / ((1 - smooth) / factor + smooth))

self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))

super().prepare_tensors()

if self._experts is not None:
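The `rope_scaling` branch added to `prepare_tensors` above turns the Llama 3 scaling parameters into one factor per RoPE frequency and writes the result out as the `MODEL_TENSOR.ROPE_FREQS` tensor. Restated as math (this is just the Python above; λ = 2π/f is the wavelength of a RoPE frequency f, L is `original_max_position_embeddings`, s is `factor`, and a, b are the low/high frequency factors):

\[
\text{rope\_factor}(\lambda) =
\begin{cases}
1, & \lambda < L/b \\
s, & \lambda > L/a \\
\left(\dfrac{1-\mu}{s} + \mu\right)^{-1} \ \text{with}\ \mu = \dfrac{L/\lambda - a}{b - a}, & \text{otherwise.}
\end{cases}
\]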
13 changes: 13 additions & 0 deletions docs/build.md
@@ -192,6 +192,19 @@ The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/c
| GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
| GGML_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. |
### MUSA
- Using `make`:
```bash
make GGML_MUSA=1
```
- Using `CMake`:
```bash
cmake -B build -DGGML_MUSA=ON
cmake --build build --config Release
```
### hipBLAS
This provides BLAS acceleration on HIP-supported AMD GPUs.
2 changes: 1 addition & 1 deletion examples/eval-callback/eval-callback.cpp
@@ -62,7 +62,7 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
} else if (type == GGML_TYPE_I8) {
v = (float) *(int8_t *) &data[i];
} else {
GGML_ASSERT(false);
GGML_ABORT("fatal error");
}
printf("%12.4f", v);
sum += v;
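This hunk, like the ones below in `imatrix.cpp`, `llama-bench.cpp`, `clip.cpp`, and `tokenize.cpp`, replaces an always-failing `GGML_ASSERT(false)` with `GGML_ABORT("...")`, so the unreachable branch fails with an explicit message rather than the stringified condition. A minimal sketch of the pattern; the helper is hypothetical, and the only assumption about `GGML_ABORT` is that it accepts a message string, as in the calls in this commit:

```cpp
#include "ggml.h" // ggml_type and GGML_ABORT come from the ggml headers in this tree

// Hypothetical helper, for illustration only.
static const char * tensor_type_name(ggml_type type) {
    switch (type) {
        case GGML_TYPE_F32: return "f32";
        case GGML_TYPE_F16: return "f16";
        default:
            // This spot would previously have used GGML_ASSERT(false);
            // the commit switches such branches to an abort with a message.
            GGML_ABORT("unsupported tensor type");
    }
}
```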
4 changes: 2 additions & 2 deletions examples/imatrix/imatrix.cpp
@@ -127,7 +127,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
}
else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
exit(1); //GGML_ASSERT(false);
exit(1); //GGML_ABORT("fatal error");
}
if (m_params.verbosity > 1) {
printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type);
@@ -176,7 +176,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
}
else if (e.values.size() != (size_t)src1->ne[0]) {
fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
exit(1); //GGML_ASSERT(false);
exit(1); //GGML_ABORT("fatal error");
}
++e.ncall;
if (m_params.verbosity > 1) {
6 changes: 3 additions & 3 deletions examples/llama-bench/llama-bench.cpp
@@ -150,7 +150,7 @@ static const char * output_format_str(output_formats format) {
case JSON: return "json";
case MARKDOWN: return "md";
case SQL: return "sql";
default: GGML_ASSERT(!"invalid output format");
default: GGML_ABORT("invalid output format");
}
}

@@ -176,7 +176,7 @@ static const char * split_mode_str(llama_split_mode mode) {
case LLAMA_SPLIT_MODE_NONE: return "none";
case LLAMA_SPLIT_MODE_LAYER: return "layer";
case LLAMA_SPLIT_MODE_ROW: return "row";
default: GGML_ASSERT(!"invalid split mode");
default: GGML_ABORT("invalid split mode");
}
}

@@ -1326,7 +1326,7 @@ static std::unique_ptr<printer> create_printer(output_formats format) {
case SQL:
return std::unique_ptr<printer>(new sql_printer());
}
GGML_ASSERT(false);
GGML_ABORT("fatal error");
}

int main(int argc, char ** argv) {
2 changes: 1 addition & 1 deletion examples/llava/clip.cpp
@@ -869,7 +869,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings = peg_0;
}
else {
GGML_ASSERT(false);
GGML_ABORT("fatal error");
}
}

20 changes: 13 additions & 7 deletions examples/save-load-state/save-load-state.cpp
@@ -47,7 +47,7 @@ int main(int argc, char ** argv) {
// save state (rng, logits, embedding and kv_cache) to file
{
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
const size_t written = llama_state_get_data(ctx, state_mem.data());
const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());

FILE *fp_write = fopen("dump_state.bin", "wb");
fwrite(state_mem.data(), 1, written, fp_write);
@@ -99,13 +99,16 @@

// load state (rng, logits, embedding and kv_cache) from file
{
std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));
std::vector<uint8_t> state_mem;

FILE * fp_read = fopen("dump_state.bin", "rb");
fseek(fp_read, 0, SEEK_END);
state_mem.resize(ftell(fp_read));
fseek(fp_read, 0, SEEK_SET);
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);

if (read != llama_state_set_data(ctx2, state_mem.data())) {
if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
llama_free(ctx2);
llama_free_model(model);
@@ -159,13 +162,16 @@

// load state (rng, logits, embedding and kv_cache) from file
{
std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));
std::vector<uint8_t> state_mem;

FILE * fp_read = fopen("dump_state.bin", "rb");
fseek(fp_read, 0, SEEK_END);
state_mem.resize(ftell(fp_read));
fseek(fp_read, 0, SEEK_SET);
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);

if (read != llama_state_set_data(ctx3, state_mem.data())) {
if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
llama_free(ctx3);
llama_free_model(model);
@@ -182,7 +188,7 @@
{
// save kv of seq 0
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0);
if (ncopy != seq_store.size()) {
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
llama_free(ctx3);
@@ -196,7 +202,7 @@
fprintf(stderr, "%s : kv cache cleared\n", __func__);

// restore kv into seq 1
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1);
if (nset != seq_store.size()) {
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
llama_free(ctx3);
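The `save-load-state` hunks above adopt state calls that take an explicit buffer size and return how many bytes were actually written or consumed; the per-sequence variants `llama_state_seq_get_data` / `llama_state_seq_set_data` gain the same parameter. A minimal sketch of the same save/restore round trip with the size-aware calls, assuming a valid `llama_context * ctx` and with error handling trimmed:

```cpp
#include <cstdio>
#include <vector>
#include "llama.h"

// Save the full context state (rng, logits, embedding, kv_cache) to a file.
static bool save_state(llama_context * ctx, const char * path) {
    std::vector<uint8_t> buf(llama_state_get_size(ctx));
    const size_t written = llama_state_get_data(ctx, buf.data(), buf.size());
    FILE * f = fopen(path, "wb");
    if (f == NULL) return false;
    const bool ok = fwrite(buf.data(), 1, written, f) == written;
    fclose(f);
    return ok;
}

// Restore the state, sizing the buffer from the file as the example above does.
static bool load_state(llama_context * ctx, const char * path) {
    FILE * f = fopen(path, "rb");
    if (f == NULL) return false;
    fseek(f, 0, SEEK_END);
    std::vector<uint8_t> buf((size_t) ftell(f));
    fseek(f, 0, SEEK_SET);
    const size_t read = fread(buf.data(), 1, buf.size(), f);
    fclose(f);
    return read == llama_state_set_data(ctx, buf.data(), buf.size());
}
```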
2 changes: 1 addition & 1 deletion examples/tokenize/tokenize.cpp
@@ -163,7 +163,7 @@ static void write_utf8_cstr_to_stdout(const char * str, bool & invalid_utf8) {
printf(">");
return;
}
GGML_ASSERT(false && "MultiByteToWideChar() failed in an unexpected way.");
GGML_ABORT("MultiByteToWideChar() failed in an unexpected way.");
}

LPWSTR wstr = (LPWSTR) calloc(length_needed+1, sizeof(*wstr));
11 changes: 9 additions & 2 deletions ggml/CMakeLists.txt
@@ -50,9 +50,15 @@ else()
set(GGML_BLAS_VENDOR_DEFAULT "Generic")
endif()

if (CMAKE_CROSSCOMPILING)
set(GGML_NATIVE_DEFAULT OFF)
else()
set(GGML_NATIVE_DEFAULT ON)
endif()

# general
option(GGML_STATIC "ggml: static link libraries" OFF)
option(GGML_NATIVE "ggml: enable -march=native flag" ON)
option(GGML_NATIVE "ggml: enable -march=native flag" ${GGML_NATIVE_DEFAULT})
option(GGML_LTO "ggml: enable link time optimization" OFF)
option(GGML_CCACHE "ggml: use ccache if available" ON)

@@ -70,7 +76,7 @@ option(GGML_SANITIZE_ADDRESS "ggml: enable address sanitizer" OFF)
option(GGML_SANITIZE_UNDEFINED "ggml: enable undefined sanitizer" OFF)

# instruction set specific
if (GGML_NATIVE)
if (GGML_NATIVE OR NOT GGML_NATIVE_DEFAULT)
set(INS_ENB OFF)
else()
set(INS_ENB ON)
@@ -107,6 +113,7 @@ set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
option(GGML_LLAMAFILE "ggml: use LLAMAFILE" OFF)

option(GGML_CUDA "ggml: use CUDA" OFF)
option(GGML_MUSA "ggml: use MUSA" OFF)
option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF)
option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF)
option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF)