diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index 5f9973163732d3..f07678e62a2b83 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -1,197 +1,343 @@ -#include "llama.h" +#include #include #include #include +#include #include #include -static void print_usage(int, char ** argv) { - printf("\nexample usage:\n"); - printf("\n %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]); - printf("\n"); +#include "llama.h" + +// Add a message to `messages` and store its content in `owned_content` +static void add_message(const std::string &role, const std::string &text, + std::vector &messages, + std::vector> &owned_content) { + auto content = std::unique_ptr(new char[text.size() + 1]); + std::strcpy(content.get(), text.c_str()); + messages.push_back({role.c_str(), content.get()}); + owned_content.push_back(std::move(content)); } -int main(int argc, char ** argv) { - std::string model_path; - int ngl = 99; - int n_ctx = 2048; +// Function to apply the chat template and resize `formatted` if needed +static int apply_chat_template(const llama_model *model, + const std::vector &messages, + std::vector &formatted, + const bool append) { + int result = llama_chat_apply_template(model, nullptr, messages.data(), + messages.size(), append, + formatted.data(), formatted.size()); + if (result > static_cast(formatted.size())) { + formatted.resize(result); + result = llama_chat_apply_template(model, nullptr, messages.data(), + messages.size(), append, + formatted.data(), formatted.size()); + } - // parse command line arguments - for (int i = 1; i < argc; i++) { - try { - if (strcmp(argv[i], "-m") == 0) { - if (i + 1 < argc) { - model_path = argv[++i]; - } else { - print_usage(argc, argv); - return 1; - } - } else if (strcmp(argv[i], "-c") == 0) { - if (i + 1 < argc) { - n_ctx = std::stoi(argv[++i]); - } else { - print_usage(argc, argv); - return 1; - } - } else if (strcmp(argv[i], "-ngl") == 0) { - if (i + 1 < argc) { - ngl = std::stoi(argv[++i]); - } else { - print_usage(argc, argv); - return 1; - } - } else { - print_usage(argc, argv); - return 1; - } - } catch (std::exception & e) { - fprintf(stderr, "error: %s\n", e.what()); - print_usage(argc, argv); - return 1; - } + return result; +} + +// Function to tokenize the prompt +static int tokenize_prompt(const llama_model *model, const std::string &prompt, + std::vector &prompt_tokens) { + const int n_prompt_tokens = -llama_tokenize( + model, prompt.c_str(), prompt.size(), NULL, 0, true, true); + prompt_tokens.resize(n_prompt_tokens); + if (llama_tokenize(model, prompt.c_str(), prompt.size(), + prompt_tokens.data(), prompt_tokens.size(), true, + true) < 0) { + GGML_ABORT("failed to tokenize the prompt\n"); + return -1; } - if (model_path.empty()) { - print_usage(argc, argv); + + return n_prompt_tokens; +} + +// Check if we have enough space in the context to evaluate this batch +static int check_context_size(const llama_context *ctx, + const llama_batch &batch) { + const int n_ctx = llama_n_ctx(ctx); + const int n_ctx_used = llama_get_kv_cache_used_cells(ctx); + if (n_ctx_used + batch.n_tokens > n_ctx) { + printf("\033[0m\n"); + fprintf(stderr, "context size exceeded\n"); return 1; } - // only print errors - llama_log_set([](enum ggml_log_level level, const char * text, void * /* user_data */) { - if (level >= GGML_LOG_LEVEL_ERROR) { - fprintf(stderr, "%s", text); - } - }, nullptr); - - // initialize the model - llama_model_params model_params = llama_model_default_params(); - model_params.n_gpu_layers = ngl; + return 0; +} - llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params); - if (!model) { - fprintf(stderr , "%s: error: unable to load model\n" , __func__); +// convert the token to a string +static int convert_token_to_string(const llama_model *model, + const llama_token token_id, + std::string &piece) { + char buf[256]; + int n = llama_token_to_piece(model, token_id, buf, sizeof(buf), 0, true); + if (n < 0) { + GGML_ABORT("failed to convert token to piece\n"); return 1; } - // initialize the context - llama_context_params ctx_params = llama_context_default_params(); - ctx_params.n_ctx = n_ctx; - ctx_params.n_batch = n_ctx; + piece = std::string(buf, n); + return 0; +} + +static void print_word_and_concatenate_to_response(const std::string &piece, + std::string &response) { + printf("%s", piece.c_str()); + fflush(stdout); + response += piece; +} - llama_context * ctx = llama_new_context_with_model(model, ctx_params); - if (!ctx) { - fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); +// helper function to evaluate a prompt and generate a response +static int generate(const llama_model *model, llama_sampler *smpl, + llama_context *ctx, const std::string &prompt, + std::string &response) { + std::vector prompt_tokens; + const int n_prompt_tokens = tokenize_prompt(model, prompt, prompt_tokens); + if (n_prompt_tokens < 0) { return 1; } - // initialize the sampler - llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params()); - llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1)); - llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f)); - llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); + // prepare a batch for the prompt + llama_batch batch = + llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); + llama_token new_token_id; + while (true) { + check_context_size(ctx, batch); + if (llama_decode(ctx, batch)) { + GGML_ABORT("failed to decode\n"); + return 1; + } - // helper function to evaluate a prompt and generate a response - auto generate = [&](const std::string & prompt) { - std::string response; + // sample the next token, check is it an end of generation? + new_token_id = llama_sampler_sample(smpl, ctx, -1); + if (llama_token_is_eog(model, new_token_id)) { + break; + } - // tokenize the prompt - const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true); - std::vector prompt_tokens(n_prompt_tokens); - if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) { - GGML_ABORT("failed to tokenize the prompt\n"); + std::string piece; + if (convert_token_to_string(model, new_token_id, piece)) { + return 1; } - // prepare a batch for the prompt - llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); - llama_token new_token_id; - while (true) { - // check if we have enough space in the context to evaluate this batch - int n_ctx = llama_n_ctx(ctx); - int n_ctx_used = llama_get_kv_cache_used_cells(ctx); - if (n_ctx_used + batch.n_tokens > n_ctx) { - printf("\033[0m\n"); - fprintf(stderr, "context size exceeded\n"); - exit(0); - } + print_word_and_concatenate_to_response(piece, response); - if (llama_decode(ctx, batch)) { - GGML_ABORT("failed to decode\n"); - } + // prepare the next batch with the sampled token + batch = llama_batch_get_one(&new_token_id, 1); + } - // sample the next token - new_token_id = llama_sampler_sample(smpl, ctx, -1); + return 0; +} - // is it an end of generation? - if (llama_token_is_eog(model, new_token_id)) { - break; - } +static void print_usage(int, const char **argv) { + printf("\nexample usage:\n"); + printf("\n %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", + argv[0]); + printf("\n"); +} - // convert the token to a string, print it and add it to the response - char buf[256]; - int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true); - if (n < 0) { - GGML_ABORT("failed to convert token to piece\n"); - } - std::string piece(buf, n); - printf("%s", piece.c_str()); - fflush(stdout); - response += piece; +static int parse_int_arg(const char *arg, int &value) { + char *end; + long val = std::strtol(arg, &end, 10); + if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) { + value = static_cast(val); + return 0; + } + + return 1; +} + +static int handle_model_path(const int argc, const char **argv, int &i, + std::string &model_path) { + if (i + 1 < argc) { + model_path = argv[++i]; + return 0; + } - // prepare the next batch with the sampled token - batch = llama_batch_get_one(&new_token_id, 1); + print_usage(argc, argv); + return 1; +} + +static int handle_n_ctx(const int argc, const char **argv, int &i, int &n_ctx) { + if (i + 1 < argc) { + if (parse_int_arg(argv[++i], n_ctx)) { + return 0; + } else { + fprintf(stderr, "error: invalid value for -c: %s\n", argv[i]); + print_usage(argc, argv); } + } else { + print_usage(argc, argv); + } - return response; - }; + return 1; +} - std::vector messages; - std::vector formatted(llama_n_ctx(ctx)); +static int handle_ngl(const int argc, const char **argv, int &i, int &ngl) { + if (i + 1 < argc) { + if (parse_int_arg(argv[++i], ngl)) { + return 0; + } else { + fprintf(stderr, "error: invalid value for -ngl: %s\n", argv[i]); + print_usage(argc, argv); + } + } else { + print_usage(argc, argv); + } + + return 1; +} + +static int parse_arguments(const int argc, const char **argv, + std::string &model_path, int &n_ctx, int &ngl) { + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "-m") == 0) { + if (handle_model_path(argc, argv, i, model_path)) { + return 1; + } + } else if (strcmp(argv[i], "-c") == 0) { + if (handle_n_ctx(argc, argv, i, n_ctx)) { + return 1; + } + } else if (strcmp(argv[i], "-ngl") == 0) { + if (handle_ngl(argc, argv, i, ngl)) { + return 1; + } + } else { + print_usage(argc, argv); + return 1; + } + } + + if (model_path.empty()) { + print_usage(argc, argv); + return 1; + } + + return 0; +} + +// The main chat loop where user inputs are processed and responses generated. +static int chat_loop(llama_model *model, llama_sampler *sampler, + llama_context *context, + std::vector &messages) { + std::vector> owned_content; + std::vector formatted(llama_n_ctx(context)); int prev_len = 0; + while (true) { - // get user input printf("\033[32m> \033[0m"); std::string user; std::getline(std::cin, user); + if (user.empty()) break; - if (user.empty()) { - break; - } - - // add the user input to the message list and format it - messages.push_back({"user", strdup(user.c_str())}); - int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size()); - if (new_len > (int)formatted.size()) { - formatted.resize(new_len); - new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size()); - } + add_message("user", user, messages, owned_content); + int new_len = apply_chat_template(model, messages, formatted, true); if (new_len < 0) { fprintf(stderr, "failed to apply the chat template\n"); return 1; } - // remove previous messages to obtain the prompt to generate the response - std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len); + std::string prompt(formatted.begin() + prev_len, + formatted.begin() + new_len); - // generate a response printf("\033[33m"); - std::string response = generate(prompt); + std::string response; + if (generate(model, sampler, context, prompt, response)) return 1; + printf("\n\033[0m"); - // add the response to the messages - messages.push_back({"assistant", strdup(response.c_str())}); - prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0); + prev_len = apply_chat_template(model, messages, formatted, false); if (prev_len < 0) { fprintf(stderr, "failed to apply the chat template\n"); return 1; } } - // free resources - for (auto & msg : messages) { - free(const_cast(msg.content)); + return 0; +} + +static void setup_logging() { + llama_log_set( + [](enum ggml_log_level level, const char *text, void *) { + if (level >= GGML_LOG_LEVEL_ERROR) fprintf(stderr, "%s", text); + }, + nullptr); +} + +// Initializes the model and returns a unique pointer to it. +static std::unique_ptr +initialize_model(const std::string &model_path, int ngl) { + llama_model_params model_params = llama_model_default_params(); + model_params.n_gpu_layers = ngl; + + auto model = std::unique_ptr( + llama_load_model_from_file(model_path.c_str(), model_params), + llama_free_model); + if (!model) { + fprintf(stderr, "%s: error: unable to load model\n", __func__); + } + + return model; +} + +// Initializes the context with the specified parameters. +static std::unique_ptr initialize_context( + llama_model *model, int n_ctx) { + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.n_ctx = n_ctx; + ctx_params.n_batch = n_ctx; + + auto context = std::unique_ptr( + llama_new_context_with_model(model, ctx_params), llama_free); + if (!context) { + fprintf(stderr, "%s: error: failed to create the llama_context\n", + __func__); + } + + return context; +} + +// Initializes and configures the sampler. +static std::unique_ptr +initialize_sampler() { + auto sampler = + std::unique_ptr( + llama_sampler_chain_init(llama_sampler_chain_default_params()), + llama_sampler_free); + llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1)); + llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f)); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); + + return sampler; +} + +int main(int argc, const char **argv) { + std::string model_path; + int ngl = 99; + int n_ctx = 2048; + if (parse_arguments(argc, argv, model_path, n_ctx, ngl)) { + return 1; + } + + setup_logging(); + + auto model = initialize_model(model_path, ngl); + if (!model) { + return 1; + } + + auto context = initialize_context(model.get(), n_ctx); + if (!context) { + return 1; + } + + auto sampler = initialize_sampler(); + std::vector messages; + if (chat_loop(model.get(), sampler.get(), context.get(), messages)) { + return 1; } - llama_sampler_free(smpl); - llama_free(ctx); - llama_free_model(model); return 0; }